emdash-core 0.1.33__py3-none-any.whl → 0.1.60__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- emdash_core/agent/agents.py +93 -23
- emdash_core/agent/background.py +481 -0
- emdash_core/agent/hooks.py +419 -0
- emdash_core/agent/inprocess_subagent.py +114 -10
- emdash_core/agent/mcp/config.py +78 -2
- emdash_core/agent/prompts/main_agent.py +88 -1
- emdash_core/agent/prompts/plan_mode.py +65 -44
- emdash_core/agent/prompts/subagents.py +96 -8
- emdash_core/agent/prompts/workflow.py +215 -50
- emdash_core/agent/providers/models.py +1 -1
- emdash_core/agent/providers/openai_provider.py +10 -0
- emdash_core/agent/research/researcher.py +154 -45
- emdash_core/agent/runner/agent_runner.py +157 -19
- emdash_core/agent/runner/context.py +28 -9
- emdash_core/agent/runner/sdk_runner.py +29 -2
- emdash_core/agent/skills.py +81 -1
- emdash_core/agent/toolkit.py +87 -11
- emdash_core/agent/toolkits/__init__.py +117 -18
- emdash_core/agent/toolkits/base.py +87 -2
- emdash_core/agent/toolkits/explore.py +18 -0
- emdash_core/agent/toolkits/plan.py +18 -0
- emdash_core/agent/tools/__init__.py +2 -0
- emdash_core/agent/tools/coding.py +344 -52
- emdash_core/agent/tools/lsp.py +361 -0
- emdash_core/agent/tools/skill.py +21 -1
- emdash_core/agent/tools/task.py +27 -23
- emdash_core/agent/tools/task_output.py +262 -32
- emdash_core/agent/verifier/__init__.py +11 -0
- emdash_core/agent/verifier/manager.py +295 -0
- emdash_core/agent/verifier/models.py +97 -0
- emdash_core/{swarm/worktree_manager.py → agent/worktree.py} +19 -1
- emdash_core/api/agent.py +451 -5
- emdash_core/api/research.py +3 -3
- emdash_core/api/router.py +0 -4
- emdash_core/context/longevity.py +197 -0
- emdash_core/context/providers/explored_areas.py +83 -39
- emdash_core/context/reranker.py +35 -144
- emdash_core/context/simple_reranker.py +500 -0
- emdash_core/context/tool_relevance.py +84 -0
- emdash_core/core/config.py +8 -0
- emdash_core/graph/__init__.py +8 -1
- emdash_core/graph/connection.py +24 -3
- emdash_core/graph/writer.py +7 -1
- emdash_core/ingestion/repository.py +17 -198
- emdash_core/models/agent.py +14 -0
- emdash_core/server.py +1 -6
- emdash_core/sse/stream.py +16 -1
- emdash_core/utils/__init__.py +0 -2
- emdash_core/utils/git.py +103 -0
- emdash_core/utils/image.py +147 -160
- {emdash_core-0.1.33.dist-info → emdash_core-0.1.60.dist-info}/METADATA +7 -5
- {emdash_core-0.1.33.dist-info → emdash_core-0.1.60.dist-info}/RECORD +54 -58
- emdash_core/api/swarm.py +0 -223
- emdash_core/db/__init__.py +0 -67
- emdash_core/db/auth.py +0 -134
- emdash_core/db/models.py +0 -91
- emdash_core/db/provider.py +0 -222
- emdash_core/db/providers/__init__.py +0 -5
- emdash_core/db/providers/supabase.py +0 -452
- emdash_core/swarm/__init__.py +0 -17
- emdash_core/swarm/merge_agent.py +0 -383
- emdash_core/swarm/session_manager.py +0 -274
- emdash_core/swarm/swarm_runner.py +0 -226
- emdash_core/swarm/task_definition.py +0 -137
- emdash_core/swarm/worker_spawner.py +0 -319
- {emdash_core-0.1.33.dist-info → emdash_core-0.1.60.dist-info}/WHEEL +0 -0
- {emdash_core-0.1.33.dist-info → emdash_core-0.1.60.dist-info}/entry_points.txt +0 -0
emdash_core/context/longevity.py ADDED
@@ -0,0 +1,197 @@
+"""Longevity tracking for context items.
+
+Tracks which entities appear repeatedly across reranking calls.
+Items that keep appearing are likely important and get boosted.
+
+This uses an in-memory cache that resets on process restart.
+For persistence, the cache could be stored in the graph database.
+"""
+
+import math
+import time
+from dataclasses import dataclass, field
+from typing import Optional
+
+
+@dataclass
+class LongevityRecord:
+    """Track an entity's appearance history."""
+
+    qualified_name: str
+    appearance_count: int = 0
+    first_seen: float = field(default_factory=time.time)
+    last_seen: float = field(default_factory=time.time)
+
+    def record_appearance(self) -> None:
+        """Record a new appearance of this entity."""
+        self.appearance_count += 1
+        self.last_seen = time.time()
+
+    def get_longevity_score(self, now: Optional[float] = None) -> float:
+        """Calculate longevity score based on appearance count.
+
+        Longevity = items that have appeared in context frame more than once.
+        No time-based decay - if it keeps appearing, it's important.
+
+        Score formula (log scale for diminishing returns):
+        - 1 appearance = 0.0 (first time, no longevity yet)
+        - 2 appearances = 0.37
+        - 3 appearances = 0.50
+        - 5 appearances = 0.62
+        - 10 appearances = 0.77
+        - 20 appearances = 0.90
+
+        Args:
+            now: Current timestamp (unused, kept for API compatibility)
+
+        Returns:
+            Score between 0.0 and 1.0
+        """
+        if self.appearance_count <= 1:
+            return 0.0  # First appearance = no longevity
+
+        # Log scale for diminishing returns
+        # Subtract 1 so first repeat (count=2) starts contributing
+        return min(1.0, math.log(self.appearance_count) / 3)
+
+
+class LongevityTracker:
+    """Tracks entity appearances across reranking calls."""
+
+    def __init__(self, max_entries: int = 1000):
+        """Initialize the tracker.
+
+        Args:
+            max_entries: Maximum number of entities to track (LRU eviction)
+        """
+        self._records: dict[str, LongevityRecord] = {}
+        self._max_entries = max_entries
+
+    def record_appearance(self, qualified_name: str) -> None:
+        """Record that an entity appeared in reranking.
+
+        Args:
+            qualified_name: The entity's qualified name
+        """
+        if qualified_name in self._records:
+            self._records[qualified_name].record_appearance()
+        else:
+            # Evict oldest entries if at capacity
+            if len(self._records) >= self._max_entries:
+                self._evict_oldest()
+
+            self._records[qualified_name] = LongevityRecord(
+                qualified_name=qualified_name,
+                appearance_count=1,
+            )
+
+    def record_batch(self, qualified_names: list[str]) -> None:
+        """Record appearances for multiple entities.
+
+        Args:
+            qualified_names: List of entity qualified names
+        """
+        for qname in qualified_names:
+            self.record_appearance(qname)
+
+    def get_longevity_score(self, qualified_name: str) -> float:
+        """Get the longevity score for an entity.
+
+        Args:
+            qualified_name: The entity's qualified name
+
+        Returns:
+            Score between 0.0 and 1.0 (0.0 if never seen)
+        """
+        record = self._records.get(qualified_name)
+        if record is None:
+            return 0.0
+        return record.get_longevity_score()
+
+    def get_appearance_count(self, qualified_name: str) -> int:
+        """Get how many times an entity has appeared.
+
+        Args:
+            qualified_name: The entity's qualified name
+
+        Returns:
+            Number of appearances (0 if never seen)
+        """
+        record = self._records.get(qualified_name)
+        return record.appearance_count if record else 0
+
+    def _evict_oldest(self) -> None:
+        """Evict the oldest (least recently seen) entries."""
+        if not self._records:
+            return
+
+        # Sort by last_seen and remove bottom 10%
+        sorted_records = sorted(
+            self._records.items(),
+            key=lambda x: x[1].last_seen,
+        )
+
+        evict_count = max(1, len(sorted_records) // 10)
+        for qname, _ in sorted_records[:evict_count]:
+            del self._records[qname]
+
+    def clear(self) -> None:
+        """Clear all longevity records."""
+        self._records.clear()
+
+    def get_stats(self) -> dict:
+        """Get statistics about the tracker.
+
+        Returns:
+            Dictionary with tracker statistics
+        """
+        if not self._records:
+            return {
+                "total_entities": 0,
+                "total_appearances": 0,
+                "avg_appearances": 0,
+                "max_appearances": 0,
+            }
+
+        appearances = [r.appearance_count for r in self._records.values()]
+        return {
+            "total_entities": len(self._records),
+            "total_appearances": sum(appearances),
+            "avg_appearances": sum(appearances) / len(appearances),
+            "max_appearances": max(appearances),
+        }
+
+
+# Global tracker instance (shared across reranking calls)
+_global_tracker: Optional[LongevityTracker] = None
+
+
+def get_longevity_tracker() -> LongevityTracker:
+    """Get the global longevity tracker (creates if needed)."""
+    global _global_tracker
+    if _global_tracker is None:
+        _global_tracker = LongevityTracker()
+    return _global_tracker
+
+
+def record_reranked_items(qualified_names: list[str]) -> None:
+    """Record that items appeared in a reranking result.
+
+    Call this after reranking to update longevity scores.
+
+    Args:
+        qualified_names: List of qualified names that were reranked
+    """
+    get_longevity_tracker().record_batch(qualified_names)
+
+
+def get_longevity_score(qualified_name: str) -> float:
+    """Get the longevity score for an entity.
+
+    Args:
+        qualified_name: The entity's qualified name
+
+    Returns:
+        Score between 0.0 and 1.0
+    """
+    return get_longevity_tracker().get_longevity_score(qualified_name)
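
A minimal usage sketch of the new longevity helpers added above. The qualified names are made-up placeholders; everything else follows the module as it appears in this diff, and the tracker state is process-local (it resets on restart):

# Sketch: exercising emdash_core/context/longevity.py from 0.1.60.
# The qualified names below are hypothetical examples.
from emdash_core.context.longevity import (
    get_longevity_score,
    get_longevity_tracker,
    record_reranked_items,
)

# Simulate three reranking passes that keep surfacing the same entity.
for _ in range(3):
    record_reranked_items(["pkg.auth.LoginService.handle", "pkg.auth.helpers"])

# After 3 appearances the score is min(1.0, ln(3) / 3), roughly 0.37.
print(get_longevity_score("pkg.auth.LoginService.handle"))

# Aggregate stats: total entities plus total/avg/max appearance counts.
print(get_longevity_tracker().get_stats())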
emdash_core/context/providers/explored_areas.py CHANGED
@@ -4,6 +4,14 @@ from dataclasses import asdict
 from typing import Optional, Union
 
 from ..models import ContextItem, ContextProviderSpec
+from ..tool_relevance import (
+    TOOL_RELEVANCE,
+    SEARCH_TOOLS,
+    TOP_RESULTS_LIMIT,
+    NON_TOP_RESULT_MULTIPLIER,
+    get_tool_relevance,
+    is_search_tool,
+)
 from .base import ContextProvider
 from ..registry import ContextProviderRegistry
 from ...graph.connection import KuzuConnection
@@ -16,44 +24,16 @@ class ExploredAreasProvider(ContextProvider):
     Analyzes the steps recorded during an agent session and assigns
     relevance scores based on the tool type used to discover each entity.
 
-
-
-
+    Scoring is defined in tool_relevance.py:
+    - Highest: Code modifications (write_to_file, apply_diff)
+    - High: Deliberate investigation (expand_node, get_callers, read_file)
+    - Medium: Targeted search (semantic_search, text_search, grep)
+    - Low: Broad discovery (list_files, graph algorithms)
     """
 
-    # Tool-based relevance scores
-    TOOL_RELEVANCE = {
-        # High relevance - deliberate investigation
-        "expand_node": 1.0,
-        "get_callers": 0.9,
-        "get_callees": 0.9,
-        "get_class_hierarchy": 0.9,
-        "get_neighbors": 0.85,
-        "get_impact_analysis": 0.85,
-        "read_file": 0.8,  # Reading a file is deliberate investigation
-        # Medium relevance - targeted search
-        "semantic_search": 0.7,
-        "text_search": 0.6,
-        "get_file_dependencies": 0.6,
-        "find_entity": 0.6,
-        # Lower relevance - broad search/modification
-        "grep": 0.4,
-        "write_to_file": 0.4,
-        "apply_diff": 0.4,
-        "get_top_pagerank": 0.3,
-        "get_communities": 0.3,
-        "list_files": 0.2,
-        "execute_command": 0.1,
-    }
-
-    # Only top N results from search tools are considered highly relevant
-    TOP_RESULTS_LIMIT = 3
-
-    # Tools where we limit to top results
-    SEARCH_TOOLS = {"semantic_search", "text_search", "grep", "find_entity"}
-
     def __init__(self, connection: KuzuConnection, config: Optional[dict] = None):
         super().__init__(connection, config)
+        self._neighbor_cache: dict[str, list[str]] = {}
 
     @property
     def spec(self) -> ContextProviderSpec:
@@ -88,10 +68,10 @@ class ExploredAreasProvider(ContextProvider):
             entities = step.get("entities_discovered", [])
 
             # Get base relevance score for this tool
-            base_score =
+            base_score = get_tool_relevance(tool_name)
 
             # For search tools, only top results are highly relevant
-            if tool_name
+            if is_search_tool(tool_name):
                 # Process top results with full score, others with reduced score
                 for i, entity in enumerate(entities):
                     qname = self._extract_qualified_name(entity)
@@ -99,10 +79,10 @@ class ExploredAreasProvider(ContextProvider):
                         continue
 
                     # Top results get full score, others get reduced
-                    if i <
+                    if i < TOP_RESULTS_LIMIT:
                         score = base_score
                     else:
-                        score = base_score *
+                        score = base_score * NON_TOP_RESULT_MULTIPLIER
 
                     self._update_entity_score(entity_scores, qname, score, entity)
             else:
@@ -120,13 +100,17 @@ class ExploredAreasProvider(ContextProvider):
             display_name = qname
             if qname.startswith("file:"):
                 display_name = qname[5:]  # Remove "file:" prefix
+
+            # Fetch neighbors from graph
+            neighbors = self._fetch_neighbors(display_name, entity_type)
+
             items.append(
                 ContextItem(
                     qualified_name=display_name,
                     entity_type=entity_type or "Unknown",
                     file_path=file_path,
                     score=score,
-                    neighbors=
+                    neighbors=neighbors,
                 )
             )
 
@@ -178,6 +162,66 @@ class ExploredAreasProvider(ContextProvider):
             return entity.get("file_path") or entity.get("path")
         return None
 
+    def _fetch_neighbors(
+        self, qualified_name: str, entity_type: Optional[str], limit: int = 5
+    ) -> list[str]:
+        """Fetch neighbors (callers/callees) from the graph.
+
+        Args:
+            qualified_name: The entity's qualified name
+            entity_type: The entity type (Function, Class, File)
+            limit: Maximum number of neighbors to return
+
+        Returns:
+            List of neighbor qualified names
+        """
+        # Check cache first
+        if qualified_name in self._neighbor_cache:
+            return self._neighbor_cache[qualified_name]
+
+        # Files don't have caller/callee relationships in the same way
+        if entity_type == "File" or not self.connection:
+            self._neighbor_cache[qualified_name] = []
+            return []
+
+        neighbors = []
+        try:
+            conn = self.connection.connect()
+
+            # Query for callers and callees
+            if entity_type in ("Function", "Class"):
+                # Get callees (what this entity calls)
+                callees_query = f"""
+                    MATCH (n:{entity_type} {{qualified_name: $qname}})-[:CALLS]->(m)
+                    RETURN m.qualified_name
+                    LIMIT $limit
+                """
+                result = conn.execute(callees_query, {"qname": qualified_name, "limit": limit})
+                while result.has_next():
+                    row = result.get_next()
+                    if row[0]:
+                        neighbors.append(row[0])
+
+                # Get callers (what calls this entity)
+                remaining = limit - len(neighbors)
+                if remaining > 0:
+                    callers_query = f"""
+                        MATCH (n)-[:CALLS]->(m:{entity_type} {{qualified_name: $qname}})
+                        RETURN n.qualified_name
+                        LIMIT $limit
+                    """
+                    result = conn.execute(callers_query, {"qname": qualified_name, "limit": remaining})
+                    while result.has_next():
+                        row = result.get_next()
+                        if row[0] and row[0] not in neighbors:
+                            neighbors.append(row[0])
+
+        except Exception as e:
+            log.debug(f"Failed to fetch neighbors for {qualified_name}: {e}")
+
+        self._neighbor_cache[qualified_name] = neighbors
+        return neighbors
+
 
 # Auto-register provider
 ContextProviderRegistry.register("explored_areas", ExploredAreasProvider)
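
The provider now pulls its scoring tables from the new emdash_core/context/tool_relevance.py (+84 lines, not included in this excerpt), so the weights and the top-result cutoff live in one module instead of being class attributes. Only the imported names are confirmed by this diff; a rough sketch under that assumption, with placeholder values (the updated docstring says code-modification tools now score highest, unlike the old in-class table where they sat at 0.4):

# Hypothetical sketch of emdash_core/context/tool_relevance.py.
# Only the names below are confirmed by the imports in explored_areas.py;
# every numeric value here is a placeholder, not the published one.
TOOL_RELEVANCE: dict[str, float] = {
    "write_to_file": 1.0,    # highest: code modifications
    "apply_diff": 1.0,
    "expand_node": 0.9,      # high: deliberate investigation
    "get_callers": 0.85,
    "read_file": 0.8,
    "semantic_search": 0.7,  # medium: targeted search
    "text_search": 0.6,
    "grep": 0.5,
    "list_files": 0.2,       # low: broad discovery
    "get_top_pagerank": 0.2,
}

# Search-style tools only get full credit for their first few hits.
SEARCH_TOOLS = {"semantic_search", "text_search", "grep", "find_entity"}
TOP_RESULTS_LIMIT = 3
NON_TOP_RESULT_MULTIPLIER = 0.5  # placeholder


def get_tool_relevance(tool_name: str) -> float:
    """Base relevance for a tool; unknown tools get a modest default."""
    return TOOL_RELEVANCE.get(tool_name, 0.3)


def is_search_tool(tool_name: str) -> bool:
    """Whether only the top results from this tool keep the full score."""
    return tool_name in SEARCH_TOOLS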
emdash_core/context/reranker.py CHANGED
@@ -1,108 +1,42 @@
 """Re-ranker for filtering context items by query relevance.
 
-Uses a
-
+Uses a lightweight scoring system based on:
+1. Text matching (query terms vs entity names/paths/descriptions)
+2. Graph signals (pagerank, betweenness centrality)
+3. Session signals (recency, touch frequency)
+4. Longevity signals (items that keep appearing are important)
+5. File co-occurrence (files with multiple entities get boosted)
+
+This reranker requires zero external ML dependencies and runs in <10ms.
 """
 
 import os
 from typing import Optional
 
-# Disable tokenizers parallelism to avoid fork warnings when running in threads
-os.environ["TOKENIZERS_PARALLELISM"] = "false"
-
 from .models import ContextItem
+from .simple_reranker import simple_rerank_items, get_simple_rerank_scores
 from ..utils.logger import log
 
-# Model singleton to avoid reloading
-_reranker_model = None
-_model_load_attempted = False
-
-
-def get_reranker_model():
-    """Get or load the re-ranker model (singleton).
-
-    Returns:
-        CrossEncoder model or None if not available
-    """
-    global _reranker_model, _model_load_attempted
-
-    if _model_load_attempted:
-        return _reranker_model
-
-    _model_load_attempted = True
-
-    # Check if re-ranking is enabled
-    if os.getenv("CONTEXT_RERANK_ENABLED", "true").lower() != "true":
-        log.debug("Context re-ranking disabled via CONTEXT_RERANK_ENABLED")
-        return None
-
-    try:
-        from sentence_transformers import CrossEncoder
-
-        model_name = os.getenv(
-            "CONTEXT_RERANK_MODEL", "mixedbread-ai/mxbai-rerank-xsmall-v1"
-        )
-        log.info(f"Loading re-ranker model: {model_name}")
-        _reranker_model = CrossEncoder(model_name)
-        log.info("Re-ranker model loaded successfully")
-        return _reranker_model
-    except ImportError:
-        log.warning("sentence-transformers not installed, re-ranking disabled")
-        return None
-    except Exception as e:
-        log.warning(f"Failed to load re-ranker model: {e}")
-        return None
-
-
-def item_to_text(item: ContextItem) -> str:
-    """Convert a ContextItem to text for re-ranking.
-
-    Args:
-        item: Context item to convert
-
-    Returns:
-        Text representation for scoring
-    """
-    parts = [item.qualified_name]
-
-    if item.entity_type:
-        parts.append(f"({item.entity_type})")
-
-    if item.description:
-        parts.append(f": {item.description[:200]}")
-
-    if item.file_path:
-        # Just include the filename, not full path
-        filename = os.path.basename(item.file_path)
-        parts.append(f" [file: {filename}]")
-
-    return " ".join(parts)
-
 
 def rerank_context_items(
     items: list[ContextItem],
     query: str,
     top_k: Optional[int] = None,
     top_percent: Optional[float] = None,
+    connection=None,
 ) -> list[ContextItem]:
     """Re-rank context items by relevance to query.
 
-    Uses a cross-encoder model to score each item against the query,
-    then returns the top K or top N% most relevant items.
-
     Args:
         items: List of context items to re-rank
         query: The user's query/task description
         top_k: Keep top K items (default from env: CONTEXT_RERANK_TOP_K=20)
         top_percent: Keep top N% items (overrides top_k if set)
+        connection: Optional Kuzu connection for graph-based scoring
 
     Returns:
         Filtered and sorted list of context items (most relevant first)
     """
-    import time
-
-    original_count = len(items)
-
     if not items:
         return items
 
@@ -110,63 +44,31 @@ def rerank_context_items(
         log.debug("No query provided for re-ranking, returning original items")
         return items
 
-
-    if
-        log.debug("
+    # Check if re-ranking is enabled
+    if os.getenv("CONTEXT_RERANK_ENABLED", "true").lower() != "true":
+        log.debug("Context re-ranking disabled via CONTEXT_RERANK_ENABLED")
         return items
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        # Sort by score descending
-        scored_items.sort(key=lambda x: x[1], reverse=True)
-
-        # Determine how many to keep
-        if top_percent is not None:
-            keep_count = max(1, int(len(items) * top_percent))
-        elif top_k is not None:
-            keep_count = min(top_k, len(items))
-        else:
-            # Default from environment
-            default_top_k = int(os.getenv("CONTEXT_RERANK_TOP_K", "20"))
-            keep_count = min(default_top_k, len(items))
-
-        duration_ms = (time.time() - start_time) * 1000
-
-        # Log statistics
-        if scored_items:
-            max_score = scored_items[0][1]
-            min_score = scored_items[-1][1]
-            filtered_count = original_count - keep_count
-            log.info(
-                f"Re-ranked context: {original_count} -> {keep_count} items "
-                f"(filtered {filtered_count}) in {duration_ms:.0f}ms | "
-                f"scores [{min_score:.3f}-{max_score:.3f}] | "
-                f"query: '{query[:40]}...'"
-            )
-
-        # Return top items (without scores)
-        return [item for item, score in scored_items[:keep_count]]
-
-    except Exception as e:
-        log.warning(f"Re-ranking failed: {e}, returning original items")
-        return items
+    # Determine effective top_k
+    if top_percent is not None:
+        effective_top_k = max(1, int(len(items) * top_percent))
+    elif top_k is not None:
+        effective_top_k = min(top_k, len(items))
+    else:
+        effective_top_k = int(os.getenv("CONTEXT_RERANK_TOP_K", "20"))
+
+    return simple_rerank_items(
+        items=items,
+        query=query,
+        connection=connection,
+        top_k=effective_top_k,
+    )
 
 
 def get_rerank_scores(
-    items: list[ContextItem],
+    items: list[ContextItem],
+    query: str,
+    connection=None,
 ) -> list[tuple[ContextItem, float]]:
     """Get re-rank scores for context items without filtering.
 
@@ -175,6 +77,7 @@ def get_rerank_scores(
     Args:
         items: List of context items
         query: Query to score against
+        connection: Optional Kuzu connection for graph signals
 
     Returns:
         List of (item, score) tuples sorted by score descending
@@ -182,18 +85,6 @@ def get_rerank_scores(
     if not items or not query:
         return [(item, 0.0) for item in items]
 
-
-
-
-
-    try:
-        texts = [item_to_text(item) for item in items]
-        pairs = [(query, text) for text in texts]
-        scores = model.predict(pairs)
-
-        scored = list(zip(items, scores))
-        scored.sort(key=lambda x: x[1], reverse=True)
-        return scored
-    except Exception as e:
-        log.warning(f"Failed to get rerank scores: {e}")
-        return [(item, 0.0) for item in items]
+    scored = get_simple_rerank_scores(items, query, connection)
+    # Return without component breakdown
+    return [(item, score) for item, score, _ in scored]