roampal 0.1.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- roampal/__init__.py +29 -0
- roampal/__main__.py +6 -0
- roampal/backend/__init__.py +1 -0
- roampal/backend/modules/__init__.py +1 -0
- roampal/backend/modules/memory/__init__.py +43 -0
- roampal/backend/modules/memory/chromadb_adapter.py +623 -0
- roampal/backend/modules/memory/config.py +102 -0
- roampal/backend/modules/memory/content_graph.py +543 -0
- roampal/backend/modules/memory/context_service.py +455 -0
- roampal/backend/modules/memory/embedding_service.py +96 -0
- roampal/backend/modules/memory/knowledge_graph_service.py +1052 -0
- roampal/backend/modules/memory/memory_bank_service.py +433 -0
- roampal/backend/modules/memory/memory_types.py +296 -0
- roampal/backend/modules/memory/outcome_service.py +400 -0
- roampal/backend/modules/memory/promotion_service.py +473 -0
- roampal/backend/modules/memory/routing_service.py +444 -0
- roampal/backend/modules/memory/scoring_service.py +324 -0
- roampal/backend/modules/memory/search_service.py +646 -0
- roampal/backend/modules/memory/tests/__init__.py +1 -0
- roampal/backend/modules/memory/tests/conftest.py +12 -0
- roampal/backend/modules/memory/tests/unit/__init__.py +1 -0
- roampal/backend/modules/memory/tests/unit/conftest.py +7 -0
- roampal/backend/modules/memory/tests/unit/test_knowledge_graph_service.py +517 -0
- roampal/backend/modules/memory/tests/unit/test_memory_bank_service.py +504 -0
- roampal/backend/modules/memory/tests/unit/test_outcome_service.py +485 -0
- roampal/backend/modules/memory/tests/unit/test_scoring_service.py +255 -0
- roampal/backend/modules/memory/tests/unit/test_search_service.py +413 -0
- roampal/backend/modules/memory/tests/unit/test_unified_memory_system.py +418 -0
- roampal/backend/modules/memory/unified_memory_system.py +1277 -0
- roampal/cli.py +638 -0
- roampal/hooks/__init__.py +16 -0
- roampal/hooks/session_manager.py +587 -0
- roampal/hooks/stop_hook.py +176 -0
- roampal/hooks/user_prompt_submit_hook.py +103 -0
- roampal/mcp/__init__.py +7 -0
- roampal/mcp/server.py +611 -0
- roampal/server/__init__.py +7 -0
- roampal/server/main.py +744 -0
- roampal-0.1.4.dist-info/METADATA +179 -0
- roampal-0.1.4.dist-info/RECORD +44 -0
- roampal-0.1.4.dist-info/WHEEL +5 -0
- roampal-0.1.4.dist-info/entry_points.txt +2 -0
- roampal-0.1.4.dist-info/licenses/LICENSE +190 -0
- roampal-0.1.4.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,646 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Search Service - Unified search with hybrid ranking and cross-encoder reranking.
|
|
3
|
+
|
|
4
|
+
Extracted from UnifiedMemorySystem as part of refactoring.
|
|
5
|
+
|
|
6
|
+
Responsibilities:
|
|
7
|
+
- Main search with hybrid ranking (vector + BM25)
|
|
8
|
+
- Cross-encoder reranking (optional)
|
|
9
|
+
- Entity boost calculation
|
|
10
|
+
- Result scoring and ranking
|
|
11
|
+
- Document effectiveness tracking
|
|
12
|
+
"""
|
|
13
|
+
|
|
14
|
+
import json
|
|
15
|
+
import logging
|
|
16
|
+
import math
|
|
17
|
+
from datetime import datetime
|
|
18
|
+
from typing import Any, Callable, Dict, List, Literal, Optional, Union
|
|
19
|
+
|
|
20
|
+
from .config import MemoryConfig
|
|
21
|
+
from .scoring_service import ScoringService, wilson_score_lower
|
|
22
|
+
from .routing_service import RoutingService, ALL_COLLECTIONS
|
|
23
|
+
from .knowledge_graph_service import KnowledgeGraphService
|
|
24
|
+
|
|
25
|
+
logger = logging.getLogger(__name__)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
# Type aliases
|
|
29
|
+
CollectionName = Literal["working", "patterns", "history", "books", "memory_bank"]
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class SearchService:
|
|
33
|
+
"""
|
|
34
|
+
Unified search with hybrid ranking.
|
|
35
|
+
|
|
36
|
+
Features:
|
|
37
|
+
- KG-based intelligent routing
|
|
38
|
+
- Vector similarity search with BM25 fusion
|
|
39
|
+
- Wilson score learning-aware ranking
|
|
40
|
+
- Cross-encoder reranking (optional)
|
|
41
|
+
- Entity boost from Content KG
|
|
42
|
+
"""
|
|
43
|
+
|
|
44
|
+
def __init__(
    self,
    collections: Dict[str, Any],  # ChromaDBAdapter instances
    scoring_service: ScoringService,
    routing_service: RoutingService,
    kg_service: KnowledgeGraphService,
    embed_fn: Callable[[str], Any],  # Async function to embed text
    config: Optional[MemoryConfig] = None,
    reranker: Optional[Any] = None,  # CrossEncoder instance
):
    """Wire up the search service with its collaborators.

    Args:
        collections: Mapping of collection name -> ChromaDBAdapter.
        scoring_service: Applies learning-aware ranking to raw hits.
        routing_service: Routes queries to collections and preprocesses text.
        kg_service: Knowledge-graph operations (concepts, known solutions).
        embed_fn: Async callable that turns text into an embedding vector.
        config: Memory configuration; a fresh MemoryConfig when omitted.
        reranker: Optional CrossEncoder used for precision reranking.
    """
    self.config = config or MemoryConfig()
    self.collections = collections
    self.scoring_service = scoring_service
    self.routing_service = routing_service
    self.kg_service = kg_service
    self.embed_fn = embed_fn
    self.reranker = reranker

    # Per-session cache of returned doc ids, consumed by outcome scoring.
    self._cached_doc_ids: Dict[str, List[str]] = {}
|
|
76
|
+
|
|
77
|
+
# =========================================================================
|
|
78
|
+
# Main Search
|
|
79
|
+
# =========================================================================
|
|
80
|
+
|
|
81
|
+
async def search(
    self,
    query: str,
    limit: int = 10,
    offset: int = 0,
    collections: Optional[List[CollectionName]] = None,
    metadata_filters: Optional[Dict[str, Any]] = None,
    return_metadata: bool = False,
    transparency_context: Optional[Any] = None
) -> Union[List[Dict], Dict]:
    """
    Search memory with intelligent routing and optional metadata filtering.

    Pipeline: route -> known-solution lookup -> embed -> hybrid search per
    collection -> scoring -> optional cross-encoder rerank -> paginate.

    Args:
        query: Search query (empty/blank query returns all items).
        limit: Max results per page.
        offset: Pagination offset.
        collections: Override automatic KG-based routing.
        metadata_filters: ChromaDB where filters.
        return_metadata: Include pagination metadata in the return value.
        transparency_context: Optional context for tracking.

    Returns:
        Ranked results (list, or dict with pagination metadata when
        return_metadata is True).
    """
    # Use KG to route query if collections not specified
    if collections is None:
        collections = self.routing_service.route_query(query)

    # Check for known problem->solution patterns.
    # NOTE(review): this lookup runs before the empty-query early return
    # below, where its result is unused - confirm it is needed there.
    known_solutions = await self.kg_service.find_known_solutions(
        query,
        self._get_fragment
    )

    # Special handling for empty query - return all items
    if not query or query.strip() == "":
        return await self._search_all(
            collections, limit, offset, return_metadata
        )

    # Preprocess query for better retrieval
    processed_query = self.routing_service.preprocess_query(query)

    # Generate query embedding; fail soft with an empty result set so a
    # broken embedding backend does not crash callers.
    try:
        query_embedding = await self.embed_fn(processed_query)
    except Exception as e:
        logger.error(f"Embedding generation failed for query '{query}': {e}")
        if return_metadata:
            return {"results": [], "total": 0, "limit": limit, "offset": offset, "has_more": False}
        return []

    # Track search start if context provided
    if transparency_context and hasattr(transparency_context, 'track_action'):
        transparency_context.track_action(
            action_type="memory_search",
            description=f"Searching: {query[:50]}{'...' if len(query) > 50 else ''}",
            detail=f"Collections: {', '.join(collections)}",
            status="executing"
        )

    # Search specified collections
    all_results = await self._search_collections(
        query_embedding, processed_query, collections, limit, metadata_filters
    )

    # Add known solutions to the beginning (they're already boosted);
    # drop any that also came back from the regular search to avoid dupes.
    if known_solutions:
        existing_ids = {r.get("id") for r in all_results}
        unique_known = [s for s in known_solutions if s.get("id") not in existing_ids]
        all_results = unique_known + all_results

    # Apply scoring and ranking
    all_results = self.scoring_service.apply_scoring_to_results(all_results)

    # Cross-encoder reranking - only worth the latency when there are
    # meaningfully more candidates than the requested page size.
    if self.reranker and len(all_results) > limit * 2:
        all_results = await self._rerank_with_cross_encoder(query, all_results, limit)

    # Track usage for KG learning (only the page actually returned)
    paginated_results = all_results[offset:offset + limit]
    self._track_search_results(query, paginated_results, transparency_context)

    # Return results
    if return_metadata:
        return {
            "results": paginated_results,
            "total": len(all_results),
            "limit": limit,
            "offset": offset,
            "has_more": (offset + limit) < len(all_results)
        }
    return paginated_results
|
|
175
|
+
|
|
176
|
+
async def _search_all(
    self,
    collections: List[str],
    limit: int,
    offset: int,
    return_metadata: bool
) -> Union[List[Dict], Dict]:
    """Handle empty query - return all items across collections, newest first.

    Args:
        collections: Collection names to dump (unknown names are skipped).
        limit: Page size.
        offset: Pagination offset.
        return_metadata: When True, wrap results with pagination info.

    Returns:
        List of results, or a dict with results/total/limit/offset/has_more.
    """
    all_results = []

    for coll_name in collections:
        if coll_name not in self.collections:
            continue

        try:
            adapter = self.collections[coll_name]
            collection_obj = adapter.collection
            # Effectively "no limit": fetch everything in the collection.
            items = collection_obj.get(limit=100000)

            for i in range(len(items['ids'])):
                # NOTE(review): assumes items['metadatas'] / ['documents']
                # are lists (ChromaDB can return None for these) - the
                # surrounding try/except would swallow a TypeError; confirm.
                metadata = items['metadatas'][i] if i < len(items['metadatas']) else {}
                result = {
                    'id': items['ids'][i],
                    'content': items['documents'][i] if i < len(items['documents']) else '',
                    'text': items['documents'][i] if i < len(items['documents']) else '',
                    'metadata': metadata,
                    'collection': coll_name
                }
                if 'score' in metadata:
                    result['score'] = metadata['score']
                all_results.append(result)
        except Exception as e:
            # Best-effort dump: one broken collection must not hide the rest.
            logger.error(f"Error getting all items from {coll_name}: {e}")

    # Sort by timestamp (ISO-8601 strings sort chronologically), newest
    # first; items without a timestamp sort last.
    all_results.sort(key=lambda x: x.get('metadata', {}).get('timestamp', ''), reverse=True)

    paginated_results = all_results[offset:offset + limit]
    if return_metadata:
        return {
            "results": paginated_results,
            "total": len(all_results),
            "limit": limit,
            "offset": offset,
            "has_more": (offset + limit) < len(all_results)
        }
    return paginated_results
|
|
223
|
+
|
|
224
|
+
async def _search_collections(
    self,
    query_embedding: List[float],
    processed_query: str,
    collections: List[str],
    limit: int,
    metadata_filters: Optional[Dict[str, Any]]
) -> List[Dict]:
    """Query each requested collection and tag/boost its hits.

    Unknown collection names (e.g. produced by routing but not configured)
    are skipped silently. Each hit is stamped with its source collection
    and adjusted in place by the collection-specific boost rules.

    Returns:
        Concatenated hits from every searched collection.
    """
    merged: List[Dict] = []

    for name in collections:
        if name not in self.collections:
            # No adapter registered under this name - skip it.
            continue

        hits = await self._search_single_collection(
            name, query_embedding, processed_query, limit, metadata_filters
        )

        for hit in hits:
            hit["collection"] = name
            self._apply_collection_boost(hit, name, processed_query)

        merged.extend(hits)

    return merged
|
|
251
|
+
|
|
252
|
+
async def _search_single_collection(
    self,
    coll_name: str,
    query_embedding: List[float],
    processed_query: str,
    limit: int,
    metadata_filters: Optional[Dict[str, Any]]
) -> List[Dict]:
    """Search one collection via hybrid (vector + BM25) query.

    Over-fetches limit * config.search_multiplier candidates so later
    scoring/reranking has room to reorder before pagination.

    Args:
        coll_name: Collection to query (must exist in self.collections).
        query_embedding: Embedding of the preprocessed query.
        processed_query: Preprocessed query text (BM25 side of the hybrid).
        limit: Requested page size (before the multiplier).
        metadata_filters: Optional ChromaDB where filters.

    Returns:
        Raw adapter hits; working-memory hits get recency annotations.
    """
    adapter = self.collections[coll_name]
    multiplier = self.config.search_multiplier

    # Build filters for memory_bank: hide archived entries by default,
    # unless the caller explicitly filtered on status themselves.
    filters = metadata_filters
    if coll_name == "memory_bank":
        filters = (metadata_filters or {}).copy()
        if "status" not in filters:
            filters["status"] = {"$ne": "archived"}

    # Hybrid query (vector + BM25)
    results = await adapter.hybrid_query(
        query_vector=query_embedding,
        query_text=processed_query,
        top_k=limit * multiplier,
        filters=filters
    )

    # Add recency metadata for working memory
    if coll_name == "working":
        self._add_recency_metadata(results)

    return results
|
|
284
|
+
|
|
285
|
+
def _apply_collection_boost(self, result: Dict, coll_name: str, query: str):
|
|
286
|
+
"""Apply collection-specific distance boosts."""
|
|
287
|
+
# Patterns get slight boost
|
|
288
|
+
if coll_name == "patterns":
|
|
289
|
+
result["distance"] = result.get("distance", 1.0) * 0.9
|
|
290
|
+
|
|
291
|
+
# Memory bank: boost by importance * confidence
|
|
292
|
+
elif coll_name == "memory_bank":
|
|
293
|
+
metadata = result.get("metadata", {})
|
|
294
|
+
importance = self._parse_numeric(metadata.get("importance", 0.7))
|
|
295
|
+
confidence = self._parse_numeric(metadata.get("confidence", 0.7))
|
|
296
|
+
quality_score = importance * confidence
|
|
297
|
+
|
|
298
|
+
# Quality boost
|
|
299
|
+
metadata_boost = 1.0 - quality_score * 0.8
|
|
300
|
+
entity_boost = self._calculate_entity_boost(query, result.get("id", ""))
|
|
301
|
+
result["distance"] = result.get("distance", 1.0) * metadata_boost / entity_boost
|
|
302
|
+
|
|
303
|
+
# Doc effectiveness boost
|
|
304
|
+
doc_id = result.get("id") or result.get("doc_id")
|
|
305
|
+
if doc_id:
|
|
306
|
+
eff = self.get_doc_effectiveness(doc_id)
|
|
307
|
+
if eff and eff.get("total_uses", 0) >= 3:
|
|
308
|
+
eff_multiplier = 0.7 + eff["success_rate"] * 0.6
|
|
309
|
+
result["distance"] = result["distance"] / eff_multiplier
|
|
310
|
+
|
|
311
|
+
# Books: boost recent uploads
|
|
312
|
+
elif coll_name == "books":
|
|
313
|
+
if result.get("upload_timestamp"):
|
|
314
|
+
try:
|
|
315
|
+
upload_time = datetime.fromisoformat(result["upload_timestamp"])
|
|
316
|
+
age_days = (datetime.utcnow() - upload_time).days
|
|
317
|
+
if age_days <= 7:
|
|
318
|
+
result["distance"] = result.get("distance", 1.0) * 0.7
|
|
319
|
+
except Exception:
|
|
320
|
+
pass
|
|
321
|
+
|
|
322
|
+
# Doc effectiveness boost
|
|
323
|
+
doc_id = result.get("id") or result.get("doc_id")
|
|
324
|
+
if doc_id:
|
|
325
|
+
eff = self.get_doc_effectiveness(doc_id)
|
|
326
|
+
if eff and eff.get("total_uses", 0) >= 3:
|
|
327
|
+
eff_multiplier = 0.7 + eff["success_rate"] * 0.6
|
|
328
|
+
result["distance"] = result.get("distance", 1.0) / eff_multiplier
|
|
329
|
+
|
|
330
|
+
def _add_recency_metadata(self, results: List[Dict]):
|
|
331
|
+
"""Add recency metadata to working memory results."""
|
|
332
|
+
for r in results:
|
|
333
|
+
metadata = r.get("metadata", {})
|
|
334
|
+
if metadata.get("timestamp"):
|
|
335
|
+
try:
|
|
336
|
+
timestamp = datetime.fromisoformat(metadata["timestamp"])
|
|
337
|
+
minutes_ago = (datetime.now() - timestamp).total_seconds() / 60
|
|
338
|
+
|
|
339
|
+
if minutes_ago < 1:
|
|
340
|
+
metadata["recency"] = "just now"
|
|
341
|
+
elif minutes_ago < 60:
|
|
342
|
+
metadata["recency"] = f"{int(minutes_ago)} minutes ago"
|
|
343
|
+
else:
|
|
344
|
+
hours_ago = minutes_ago / 60
|
|
345
|
+
metadata["recency"] = f"{int(hours_ago)} hours ago"
|
|
346
|
+
|
|
347
|
+
r["minutes_ago"] = minutes_ago
|
|
348
|
+
except Exception:
|
|
349
|
+
r["minutes_ago"] = 999
|
|
350
|
+
|
|
351
|
+
def _parse_numeric(self, value: Any) -> float:
|
|
352
|
+
"""Parse value to float, handling various formats."""
|
|
353
|
+
if isinstance(value, (list, tuple)):
|
|
354
|
+
return float(value[0]) if value else 0.7
|
|
355
|
+
elif isinstance(value, str):
|
|
356
|
+
level_map = {'high': 0.9, 'medium': 0.7, 'low': 0.5}
|
|
357
|
+
return level_map.get(value.lower(), 0.7)
|
|
358
|
+
try:
|
|
359
|
+
return float(value) if value else 0.7
|
|
360
|
+
except (ValueError, TypeError):
|
|
361
|
+
return 0.7
|
|
362
|
+
|
|
363
|
+
# =========================================================================
|
|
364
|
+
# Entity Boost Calculation
|
|
365
|
+
# =========================================================================
|
|
366
|
+
|
|
367
|
+
def _calculate_entity_boost(self, query: str, doc_id: str) -> float:
    """
    Calculate quality boost based on Content KG entities.

    Only applied to memory_bank searches - boosts documents containing
    high-quality entities that match query concepts. Any failure is
    swallowed and yields a neutral 1.0.

    Args:
        query: Preprocessed query text.
        doc_id: Candidate document id.

    Returns:
        Boost multiplier (1.0 = no boost, up to 1.5 = 50% boost)
    """
    try:
        # Concepts shorter than 3 chars are too noisy to count as entities.
        query_concepts = self.kg_service.extract_concepts(query)
        query_entities = [c for c in query_concepts if len(c) >= 3]

        if not query_entities:
            return 1.0

        # NOTE(review): reaches into content_graph._doc_entities, a
        # private attribute of the content graph - consider exposing a
        # public accessor instead.
        doc_entities = self.kg_service.content_graph._doc_entities.get(doc_id, set())

        if not doc_entities:
            return 1.0

        # Sum avg_quality over entities shared by the query and the doc.
        total_boost = 0.0
        for entity in query_entities:
            if entity in doc_entities and entity in self.kg_service.content_graph.entities:
                entity_quality = self.kg_service.content_graph.entities[entity].get("avg_quality", 0.0)
                total_boost += entity_quality

        # Cap boost at 50% (each point of summed quality adds 20%).
        boost_multiplier = 1.0 + min(total_boost * 0.2, 0.5)

        if boost_multiplier > 1.0:
            logger.debug(f"Entity boost for {doc_id}: {boost_multiplier:.2f}x")

        return boost_multiplier
    except Exception as e:
        logger.error(f"Error calculating entity boost: {e}")
        return 1.0
|
|
405
|
+
|
|
406
|
+
# =========================================================================
|
|
407
|
+
# Document Effectiveness
|
|
408
|
+
# =========================================================================
|
|
409
|
+
|
|
410
|
+
def get_doc_effectiveness(self, doc_id: str) -> Optional[Dict[str, Any]]:
    """
    Aggregate success rate for a specific doc from Action KG examples.

    Scans every context->action effectiveness bucket for examples that
    reference *doc_id* and tallies their outcomes. A partial outcome
    counts as half a success in the success rate.

    Args:
        doc_id: Document identifier to aggregate.

    Returns:
        Dict with successes, failures, partials, total_uses and
        success_rate, or None if the doc has no recorded uses.
    """
    successes = 0
    failures = 0
    partials = 0

    # Only the bucket values matter here; keys are context/action pairs.
    effectiveness = self.kg_service.knowledge_graph.get("context_action_effectiveness", {})
    for stats in effectiveness.values():
        for example in stats.get("examples", []):
            if example.get("doc_id") != doc_id:
                continue
            # .get tolerates malformed examples (missing "outcome" would
            # previously raise KeyError); they count as partial below.
            outcome = example.get("outcome")
            if outcome == "worked":
                successes += 1
            elif outcome == "failed":
                failures += 1
            else:
                partials += 1

    total = successes + failures + partials
    if total == 0:
        return None

    return {
        "successes": successes,
        "failures": failures,
        "partials": partials,
        "total_uses": total,
        "success_rate": (successes + partials * 0.5) / total
    }
|
|
442
|
+
|
|
443
|
+
# =========================================================================
|
|
444
|
+
# Cross-Encoder Reranking
|
|
445
|
+
# =========================================================================
|
|
446
|
+
|
|
447
|
+
async def _rerank_with_cross_encoder(
    self,
    query: str,
    candidates: List[Dict],
    top_k: int
) -> List[Dict]:
    """
    Rerank top candidates with cross-encoder for precision.

    Scores the current top-30 candidates against the query, blends that
    score into the existing rank score (quality-weighted for
    memory_bank), re-sorts them, and appends the untouched remainder.
    Falls back to the input order on any failure.

    Args:
        query: Original (unprocessed) query text.
        candidates: Scored results carrying "final_rank_score".
        top_k: Requested page size (gating is done by the caller).

    Returns:
        Reranked results with cross-encoder scores
    """
    if not self.reranker:
        return candidates

    try:
        # Take top-30 candidates for reranking (cross-encoders are slow).
        top_candidates = sorted(
            candidates,
            key=lambda x: x.get("final_rank_score", 0.0),
            reverse=True
        )[:30]

        # Prepare query-document pairs, truncated to 512 chars to fit
        # typical cross-encoder input limits.
        pairs = []
        for candidate in top_candidates:
            doc_text = candidate.get("text", "")
            if not doc_text and candidate.get("metadata"):
                doc_text = candidate.get("metadata", {}).get("content", "")
            pairs.append([query, doc_text[:512]])

        # Score with cross-encoder
        ce_scores = self.reranker.predict(pairs, batch_size=32, show_progress_bar=False)

        # Blend scores
        for i, candidate in enumerate(top_candidates):
            ce_score = float(ce_scores[i])
            candidate["ce_score"] = ce_score
            original_score = candidate.get("final_rank_score", 0.5)

            collection = candidate.get("collection", "")
            if collection == "memory_bank":
                # Memory bank: lighter CE weight, plus a multiplier from
                # the entry's importance * confidence quality.
                metadata = candidate.get("metadata", {})
                importance = self._parse_numeric(metadata.get("importance", 0.7))
                confidence = self._parse_numeric(metadata.get("confidence", 0.7))
                quality = importance * confidence

                # NOTE(review): (ce_score + 1) / 2 assumes CE scores lie in
                # [-1, 1]; some cross-encoder models emit unbounded logits
                # - confirm for the configured reranker.
                ce_norm = (ce_score + 1) / 2
                ce_weight = 0.3
                quality_boost = 1.0 + quality * 0.3
                blended = ((1 - ce_weight) * original_score + ce_weight * ce_norm) * quality_boost
            else:
                ce_norm = (ce_score + 1) / 2
                ce_weight = 0.4
                blended = (1 - ce_weight) * original_score + ce_weight * ce_norm

            candidate["final_rank_score"] = blended

        # Re-sort by updated score
        top_candidates.sort(key=lambda x: x.get("final_rank_score", 0.0), reverse=True)

        # Merge back: use reranked top + remaining
        top_ids = {c.get("id") for c in top_candidates}
        remaining = [c for c in candidates if c.get("id") not in top_ids]

        return top_candidates + remaining

    except Exception as e:
        logger.error(f"Cross-encoder reranking failed: {e}")
        return candidates
|
|
517
|
+
|
|
518
|
+
# =========================================================================
|
|
519
|
+
# Tracking and Caching
|
|
520
|
+
# =========================================================================
|
|
521
|
+
|
|
522
|
+
def _track_search_results(
    self,
    query: str,
    results: List[Dict],
    transparency_context: Optional[Any]
):
    """Track search results for KG learning.

    Side effects: seeds KG routing patterns per returned result, caches
    doc ids for the session (for outcome scoring), and reports the
    search to the transparency context when one is provided.

    Args:
        query: Original query text.
        results: The page of results actually returned to the caller.
        transparency_context: Optional tracker; session_id and
            track_memory_search are used when present.
    """
    # Track usage for returned results
    for result in results:
        if "collection" in result and "id" in result:
            self._track_usage(query, result["collection"], result["id"])

    # Cache doc_ids for outcome scoring
    session_id = 'default'
    if transparency_context and hasattr(transparency_context, 'session_id'):
        session_id = transparency_context.session_id

    # Only the learnable collections participate in outcome scoring.
    cached_doc_ids = []
    for result in results:
        collection = result.get("collection", "")
        if collection in ["working", "history", "patterns"]:
            doc_id = result.get("id")
            if doc_id:
                cached_doc_ids.append(doc_id)

    self._cached_doc_ids[session_id] = cached_doc_ids
    if cached_doc_ids:
        logger.debug(f"Cached {len(cached_doc_ids)} doc_ids for outcome scoring")

    # Track with transparency context
    if transparency_context and hasattr(transparency_context, 'track_memory_search'):
        # Map distance to a (0, 1] confidence; smaller distance -> higher.
        confidence_scores = [math.exp(-r.get("distance", 0.5) / 100.0) for r in results]
        transparency_context.track_memory_search(
            query=query,
            # NOTE(review): only the first result's collection is reported
            # here (results[:1]) - confirm whether all collections were
            # intended.
            collections=[r.get("collection") for r in results[:1]],
            results_count=len(results),
            confidence_scores=confidence_scores
        )
|
|
560
|
+
|
|
561
|
+
def _track_usage(self, query: str, collection: str, doc_id: str):
|
|
562
|
+
"""Track which collection was used for which query."""
|
|
563
|
+
concepts = self.kg_service.extract_concepts(query)
|
|
564
|
+
for concept in concepts:
|
|
565
|
+
if concept not in self.kg_service.knowledge_graph["routing_patterns"]:
|
|
566
|
+
self.kg_service.knowledge_graph["routing_patterns"][concept] = {
|
|
567
|
+
"collections_used": {},
|
|
568
|
+
"best_collection": collection,
|
|
569
|
+
"success_rate": 0.5
|
|
570
|
+
}
|
|
571
|
+
|
|
572
|
+
def _get_fragment(self, collection: str, doc_id: str) -> Optional[Dict]:
|
|
573
|
+
"""Get a document fragment by collection and ID."""
|
|
574
|
+
if collection in self.collections:
|
|
575
|
+
return self.collections[collection].get_fragment(doc_id)
|
|
576
|
+
return None
|
|
577
|
+
|
|
578
|
+
def get_cached_doc_ids(self, session_id: str = 'default') -> List[str]:
    """Return the doc ids cached for *session_id* (empty list if none)."""
    cached = self._cached_doc_ids.get(session_id)
    return cached if cached is not None else []
|
|
581
|
+
|
|
582
|
+
# =========================================================================
|
|
583
|
+
# Book Search (Specialized)
|
|
584
|
+
# =========================================================================
|
|
585
|
+
|
|
586
|
+
async def search_books(
    self,
    query: str,
    chunk_type: Optional[str] = None,
    has_code: Optional[bool] = None,
    code_language: Optional[str] = None,
    n_results: int = 5
) -> List[Dict]:
    """
    Enhanced book search with metadata filtering.

    Fails soft: missing collection or a query error returns [].

    Args:
        query: Search query text
        chunk_type: Filter by chunk type ("code", "prose", "mixed")
        has_code: Filter by presence of code
        code_language: Filter by programming language
        n_results: Number of results to return

    Returns:
        List of search results with enhanced context
    """
    if "books" not in self.collections:
        logger.warning("Books collection not initialized")
        return []

    # Build where clause
    # NOTE(review): multiple keys in a ChromaDB where clause require an
    # explicit {"$and": [...]} wrapper on newer ChromaDB versions -
    # confirm the adapter handles that translation.
    where = {}
    if chunk_type:
        where["chunk_type"] = chunk_type
    if has_code is not None:
        where["has_code"] = has_code
    if code_language:
        where["code_language"] = code_language

    try:
        # Over-fetch 2x, then trim to n_results while formatting below.
        results = await self.collections["books"].query(
            query_texts=[query],
            n_results=n_results * 2,
            where=where if where else None
        )

        if not results or not results.get("ids"):
            return []

        # Format results (ChromaDB returns per-query, column-major lists,
        # hence the [0] indexing for the single query).
        formatted_results = []
        for i in range(min(n_results, len(results["ids"][0]))):
            result = {
                "id": results["ids"][0][i],
                "text": results["documents"][0][i],
                "metadata": results["metadatas"][0][i] if results.get("metadatas") else {},
                "distance": results["distances"][0][i] if results.get("distances") else 0,
                "collection": "books"
            }
            formatted_results.append(result)

        return formatted_results

    except Exception as e:
        logger.error(f"Book search failed: {e}")
        return []
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
# Memory module tests
|
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Pytest configuration for memory module tests.
|
|
3
|
+
|
|
4
|
+
Sets up import paths for roampal-core.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import sys
|
|
8
|
+
import os
|
|
9
|
+
|
|
10
|
+
# Add roampal-core root to path
|
|
11
|
+
CORE_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..', '..', '..', '..'))
|
|
12
|
+
sys.path.insert(0, CORE_ROOT)
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
# Unit tests for memory module
|