roampal 0.1.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44)
  1. roampal/__init__.py +29 -0
  2. roampal/__main__.py +6 -0
  3. roampal/backend/__init__.py +1 -0
  4. roampal/backend/modules/__init__.py +1 -0
  5. roampal/backend/modules/memory/__init__.py +43 -0
  6. roampal/backend/modules/memory/chromadb_adapter.py +623 -0
  7. roampal/backend/modules/memory/config.py +102 -0
  8. roampal/backend/modules/memory/content_graph.py +543 -0
  9. roampal/backend/modules/memory/context_service.py +455 -0
  10. roampal/backend/modules/memory/embedding_service.py +96 -0
  11. roampal/backend/modules/memory/knowledge_graph_service.py +1052 -0
  12. roampal/backend/modules/memory/memory_bank_service.py +433 -0
  13. roampal/backend/modules/memory/memory_types.py +296 -0
  14. roampal/backend/modules/memory/outcome_service.py +400 -0
  15. roampal/backend/modules/memory/promotion_service.py +473 -0
  16. roampal/backend/modules/memory/routing_service.py +444 -0
  17. roampal/backend/modules/memory/scoring_service.py +324 -0
  18. roampal/backend/modules/memory/search_service.py +646 -0
  19. roampal/backend/modules/memory/tests/__init__.py +1 -0
  20. roampal/backend/modules/memory/tests/conftest.py +12 -0
  21. roampal/backend/modules/memory/tests/unit/__init__.py +1 -0
  22. roampal/backend/modules/memory/tests/unit/conftest.py +7 -0
  23. roampal/backend/modules/memory/tests/unit/test_knowledge_graph_service.py +517 -0
  24. roampal/backend/modules/memory/tests/unit/test_memory_bank_service.py +504 -0
  25. roampal/backend/modules/memory/tests/unit/test_outcome_service.py +485 -0
  26. roampal/backend/modules/memory/tests/unit/test_scoring_service.py +255 -0
  27. roampal/backend/modules/memory/tests/unit/test_search_service.py +413 -0
  28. roampal/backend/modules/memory/tests/unit/test_unified_memory_system.py +418 -0
  29. roampal/backend/modules/memory/unified_memory_system.py +1277 -0
  30. roampal/cli.py +638 -0
  31. roampal/hooks/__init__.py +16 -0
  32. roampal/hooks/session_manager.py +587 -0
  33. roampal/hooks/stop_hook.py +176 -0
  34. roampal/hooks/user_prompt_submit_hook.py +103 -0
  35. roampal/mcp/__init__.py +7 -0
  36. roampal/mcp/server.py +611 -0
  37. roampal/server/__init__.py +7 -0
  38. roampal/server/main.py +744 -0
  39. roampal-0.1.4.dist-info/METADATA +179 -0
  40. roampal-0.1.4.dist-info/RECORD +44 -0
  41. roampal-0.1.4.dist-info/WHEEL +5 -0
  42. roampal-0.1.4.dist-info/entry_points.txt +2 -0
  43. roampal-0.1.4.dist-info/licenses/LICENSE +190 -0
  44. roampal-0.1.4.dist-info/top_level.txt +1 -0
@@ -0,0 +1,646 @@
1
+ """
2
+ Search Service - Unified search with hybrid ranking and cross-encoder reranking.
3
+
4
+ Extracted from UnifiedMemorySystem as part of refactoring.
5
+
6
+ Responsibilities:
7
+ - Main search with hybrid ranking (vector + BM25)
8
+ - Cross-encoder reranking (optional)
9
+ - Entity boost calculation
10
+ - Result scoring and ranking
11
+ - Document effectiveness tracking
12
+ """
13
+
14
+ import json
15
+ import logging
16
+ import math
17
+ from datetime import datetime
18
+ from typing import Any, Callable, Dict, List, Literal, Optional, Union
19
+
20
+ from .config import MemoryConfig
21
+ from .scoring_service import ScoringService, wilson_score_lower
22
+ from .routing_service import RoutingService, ALL_COLLECTIONS
23
+ from .knowledge_graph_service import KnowledgeGraphService
24
+
25
+ logger = logging.getLogger(__name__)
26
+
27
+
28
+ # Type aliases
29
+ CollectionName = Literal["working", "patterns", "history", "books", "memory_bank"]
30
+
31
+
32
+ class SearchService:
33
+ """
34
+ Unified search with hybrid ranking.
35
+
36
+ Features:
37
+ - KG-based intelligent routing
38
+ - Vector similarity search with BM25 fusion
39
+ - Wilson score learning-aware ranking
40
+ - Cross-encoder reranking (optional)
41
+ - Entity boost from Content KG
42
+ """
43
+
44
+ def __init__(
45
+ self,
46
+ collections: Dict[str, Any], # ChromaDBAdapter instances
47
+ scoring_service: ScoringService,
48
+ routing_service: RoutingService,
49
+ kg_service: KnowledgeGraphService,
50
+ embed_fn: Callable[[str], Any], # Async function to embed text
51
+ config: Optional[MemoryConfig] = None,
52
+ reranker: Optional[Any] = None, # CrossEncoder instance
53
+ ):
54
+ """
55
+ Initialize SearchService.
56
+
57
+ Args:
58
+ collections: Dict mapping collection name to ChromaDBAdapter
59
+ scoring_service: ScoringService for ranking
60
+ routing_service: RoutingService for query routing
61
+ kg_service: KnowledgeGraphService for KG operations
62
+ embed_fn: Async function to generate embeddings
63
+ config: Optional MemoryConfig
64
+ reranker: Optional CrossEncoder for reranking
65
+ """
66
+ self.collections = collections
67
+ self.scoring_service = scoring_service
68
+ self.routing_service = routing_service
69
+ self.kg_service = kg_service
70
+ self.embed_fn = embed_fn
71
+ self.config = config or MemoryConfig()
72
+ self.reranker = reranker
73
+
74
+ # Cache for doc_ids per session (for outcome scoring)
75
+ self._cached_doc_ids: Dict[str, List[str]] = {}
76
+
77
+ # =========================================================================
78
+ # Main Search
79
+ # =========================================================================
80
+
81
    async def search(
        self,
        query: str,
        limit: int = 10,
        offset: int = 0,
        collections: Optional[List[CollectionName]] = None,
        metadata_filters: Optional[Dict[str, Any]] = None,
        return_metadata: bool = False,
        transparency_context: Optional[Any] = None
    ) -> Union[List[Dict], Dict]:
        """
        Search memory with intelligent routing and optional metadata filtering.

        Pipeline: route query -> look up known solutions -> embed -> hybrid
        search per collection -> score -> optional cross-encoder rerank ->
        paginate and track usage.

        Args:
            query: Search query. Empty/whitespace queries skip vector search
                and return all items (newest first) via _search_all.
            limit: Max results per page
            offset: Pagination offset
            collections: Override automatic KG-based routing
            metadata_filters: ChromaDB where filters
            return_metadata: Include pagination metadata in a wrapper dict
            transparency_context: Optional context for tracking

        Returns:
            Ranked results (list, or dict with pagination metadata when
            return_metadata is True). Embedding failure returns an empty
            result set instead of raising.
        """
        # Use KG to route query if collections not specified
        if collections is None:
            collections = self.routing_service.route_query(query)

        # Check for known problem->solution patterns; prepended to results below.
        known_solutions = await self.kg_service.find_known_solutions(
            query,
            self._get_fragment
        )

        # Special handling for empty query - return all items
        # NOTE(review): this runs after the known-solutions lookup above, so
        # for empty queries that lookup is wasted work — confirm the ordering.
        if not query or query.strip() == "":
            return await self._search_all(
                collections, limit, offset, return_metadata
            )

        # Preprocess query for better retrieval
        processed_query = self.routing_service.preprocess_query(query)

        # Generate query embedding; fail soft with an empty result set.
        try:
            query_embedding = await self.embed_fn(processed_query)
        except Exception as e:
            logger.error(f"Embedding generation failed for query '{query}': {e}")
            if return_metadata:
                return {"results": [], "total": 0, "limit": limit, "offset": offset, "has_more": False}
            return []

        # Track search start if context provided
        if transparency_context and hasattr(transparency_context, 'track_action'):
            transparency_context.track_action(
                action_type="memory_search",
                description=f"Searching: {query[:50]}{'...' if len(query) > 50 else ''}",
                detail=f"Collections: {', '.join(collections)}",
                status="executing"
            )

        # Search specified collections
        all_results = await self._search_collections(
            query_embedding, processed_query, collections, limit, metadata_filters
        )

        # Add known solutions to the beginning (they're already boosted),
        # de-duplicating against hits the search already found by id.
        if known_solutions:
            existing_ids = {r.get("id") for r in all_results}
            unique_known = [s for s in known_solutions if s.get("id") not in existing_ids]
            all_results = unique_known + all_results

        # Apply scoring and ranking
        all_results = self.scoring_service.apply_scoring_to_results(all_results)

        # Cross-encoder reranking — only worthwhile when there are enough
        # candidates beyond the page size.
        if self.reranker and len(all_results) > limit * 2:
            all_results = await self._rerank_with_cross_encoder(query, all_results, limit)

        # Track usage for KG learning (only the page actually returned).
        paginated_results = all_results[offset:offset + limit]
        self._track_search_results(query, paginated_results, transparency_context)

        # Return results
        if return_metadata:
            return {
                "results": paginated_results,
                "total": len(all_results),
                "limit": limit,
                "offset": offset,
                "has_more": (offset + limit) < len(all_results)
            }
        return paginated_results
175
+
176
+ async def _search_all(
177
+ self,
178
+ collections: List[str],
179
+ limit: int,
180
+ offset: int,
181
+ return_metadata: bool
182
+ ) -> Union[List[Dict], Dict]:
183
+ """Handle empty query - return all items."""
184
+ all_results = []
185
+
186
+ for coll_name in collections:
187
+ if coll_name not in self.collections:
188
+ continue
189
+
190
+ try:
191
+ adapter = self.collections[coll_name]
192
+ collection_obj = adapter.collection
193
+ items = collection_obj.get(limit=100000)
194
+
195
+ for i in range(len(items['ids'])):
196
+ metadata = items['metadatas'][i] if i < len(items['metadatas']) else {}
197
+ result = {
198
+ 'id': items['ids'][i],
199
+ 'content': items['documents'][i] if i < len(items['documents']) else '',
200
+ 'text': items['documents'][i] if i < len(items['documents']) else '',
201
+ 'metadata': metadata,
202
+ 'collection': coll_name
203
+ }
204
+ if 'score' in metadata:
205
+ result['score'] = metadata['score']
206
+ all_results.append(result)
207
+ except Exception as e:
208
+ logger.error(f"Error getting all items from {coll_name}: {e}")
209
+
210
+ # Sort by timestamp
211
+ all_results.sort(key=lambda x: x.get('metadata', {}).get('timestamp', ''), reverse=True)
212
+
213
+ paginated_results = all_results[offset:offset + limit]
214
+ if return_metadata:
215
+ return {
216
+ "results": paginated_results,
217
+ "total": len(all_results),
218
+ "limit": limit,
219
+ "offset": offset,
220
+ "has_more": (offset + limit) < len(all_results)
221
+ }
222
+ return paginated_results
223
+
224
+ async def _search_collections(
225
+ self,
226
+ query_embedding: List[float],
227
+ processed_query: str,
228
+ collections: List[str],
229
+ limit: int,
230
+ metadata_filters: Optional[Dict[str, Any]]
231
+ ) -> List[Dict]:
232
+ """Search specified collections and apply boosts."""
233
+ all_results = []
234
+
235
+ for coll_name in collections:
236
+ if coll_name not in self.collections:
237
+ continue
238
+
239
+ results = await self._search_single_collection(
240
+ coll_name, query_embedding, processed_query, limit, metadata_filters
241
+ )
242
+
243
+ # Apply collection-specific boosts
244
+ for r in results:
245
+ r["collection"] = coll_name
246
+ self._apply_collection_boost(r, coll_name, processed_query)
247
+
248
+ all_results.extend(results)
249
+
250
+ return all_results
251
+
252
+ async def _search_single_collection(
253
+ self,
254
+ coll_name: str,
255
+ query_embedding: List[float],
256
+ processed_query: str,
257
+ limit: int,
258
+ metadata_filters: Optional[Dict[str, Any]]
259
+ ) -> List[Dict]:
260
+ """Search a single collection."""
261
+ adapter = self.collections[coll_name]
262
+ multiplier = self.config.search_multiplier
263
+
264
+ # Build filters for memory_bank
265
+ filters = metadata_filters
266
+ if coll_name == "memory_bank":
267
+ filters = (metadata_filters or {}).copy()
268
+ if "status" not in filters:
269
+ filters["status"] = {"$ne": "archived"}
270
+
271
+ # Hybrid query (vector + BM25)
272
+ results = await adapter.hybrid_query(
273
+ query_vector=query_embedding,
274
+ query_text=processed_query,
275
+ top_k=limit * multiplier,
276
+ filters=filters
277
+ )
278
+
279
+ # Add recency metadata for working memory
280
+ if coll_name == "working":
281
+ self._add_recency_metadata(results)
282
+
283
+ return results
284
+
285
    def _apply_collection_boost(self, result: Dict, coll_name: str, query: str):
        """
        Apply collection-specific distance adjustments in place.

        Smaller distance = better match, so dividing by a multiplier > 1
        (or multiplying by a factor < 1) is a boost.

        Args:
            result: Single search hit; its "distance" key is mutated.
            coll_name: Source collection of the hit.
            query: Processed query text (used for entity boosting).
        """
        # Patterns get slight boost
        if coll_name == "patterns":
            result["distance"] = result.get("distance", 1.0) * 0.9

        # Memory bank: boost by importance * confidence
        elif coll_name == "memory_bank":
            metadata = result.get("metadata", {})
            importance = self._parse_numeric(metadata.get("importance", 0.7))
            confidence = self._parse_numeric(metadata.get("confidence", 0.7))
            quality_score = importance * confidence

            # Quality boost: quality 1.0 shrinks distance by up to 80%;
            # entity_boost (>= 1.0) divides distance further.
            metadata_boost = 1.0 - quality_score * 0.8
            entity_boost = self._calculate_entity_boost(query, result.get("id", ""))
            result["distance"] = result.get("distance", 1.0) * metadata_boost / entity_boost

            # Doc effectiveness boost: only after >= 3 recorded uses;
            # multiplier ranges 0.7 (all failures) .. 1.3 (all successes).
            doc_id = result.get("id") or result.get("doc_id")
            if doc_id:
                eff = self.get_doc_effectiveness(doc_id)
                if eff and eff.get("total_uses", 0) >= 3:
                    eff_multiplier = 0.7 + eff["success_rate"] * 0.6
                    result["distance"] = result["distance"] / eff_multiplier

        # Books: boost recent uploads
        elif coll_name == "books":
            if result.get("upload_timestamp"):
                try:
                    upload_time = datetime.fromisoformat(result["upload_timestamp"])
                    # NOTE(review): naive utcnow() vs. fromisoformat — a
                    # tz-aware timestamp raises here and is silently ignored.
                    age_days = (datetime.utcnow() - upload_time).days
                    if age_days <= 7:
                        result["distance"] = result.get("distance", 1.0) * 0.7
                except Exception:
                    pass

            # Doc effectiveness boost
            # NOTE(review): placement reconstructed as part of the books
            # branch (parallel to memory_bank above) — confirm it was not
            # meant to apply to every collection at the outer level.
            doc_id = result.get("id") or result.get("doc_id")
            if doc_id:
                eff = self.get_doc_effectiveness(doc_id)
                if eff and eff.get("total_uses", 0) >= 3:
                    eff_multiplier = 0.7 + eff["success_rate"] * 0.6
                    result["distance"] = result.get("distance", 1.0) / eff_multiplier
329
+
330
+ def _add_recency_metadata(self, results: List[Dict]):
331
+ """Add recency metadata to working memory results."""
332
+ for r in results:
333
+ metadata = r.get("metadata", {})
334
+ if metadata.get("timestamp"):
335
+ try:
336
+ timestamp = datetime.fromisoformat(metadata["timestamp"])
337
+ minutes_ago = (datetime.now() - timestamp).total_seconds() / 60
338
+
339
+ if minutes_ago < 1:
340
+ metadata["recency"] = "just now"
341
+ elif minutes_ago < 60:
342
+ metadata["recency"] = f"{int(minutes_ago)} minutes ago"
343
+ else:
344
+ hours_ago = minutes_ago / 60
345
+ metadata["recency"] = f"{int(hours_ago)} hours ago"
346
+
347
+ r["minutes_ago"] = minutes_ago
348
+ except Exception:
349
+ r["minutes_ago"] = 999
350
+
351
+ def _parse_numeric(self, value: Any) -> float:
352
+ """Parse value to float, handling various formats."""
353
+ if isinstance(value, (list, tuple)):
354
+ return float(value[0]) if value else 0.7
355
+ elif isinstance(value, str):
356
+ level_map = {'high': 0.9, 'medium': 0.7, 'low': 0.5}
357
+ return level_map.get(value.lower(), 0.7)
358
+ try:
359
+ return float(value) if value else 0.7
360
+ except (ValueError, TypeError):
361
+ return 0.7
362
+
363
+ # =========================================================================
364
+ # Entity Boost Calculation
365
+ # =========================================================================
366
+
367
    def _calculate_entity_boost(self, query: str, doc_id: str) -> float:
        """
        Calculate quality boost based on Content KG entities.

        Only applies to memory_bank searches - boosts documents containing
        high-quality entities that match query concepts. Any failure is
        swallowed and treated as "no boost".

        Args:
            query: Search query; concepts are extracted from it.
            doc_id: Candidate document id to look up in the content graph.

        Returns:
            Boost multiplier (1.0 = no boost, up to 1.5 = 50% boost)
        """
        try:
            # Concepts shorter than 3 chars are too noisy to treat as entities.
            query_concepts = self.kg_service.extract_concepts(query)
            query_entities = [c for c in query_concepts if len(c) >= 3]

            if not query_entities:
                return 1.0

            # NOTE(review): reaches into the content graph's private
            # _doc_entities index — consider exposing a public accessor.
            doc_entities = self.kg_service.content_graph._doc_entities.get(doc_id, set())

            if not doc_entities:
                return 1.0

            # Sum avg_quality over entities shared by query and document.
            total_boost = 0.0
            for entity in query_entities:
                if entity in doc_entities and entity in self.kg_service.content_graph.entities:
                    entity_quality = self.kg_service.content_graph.entities[entity].get("avg_quality", 0.0)
                    total_boost += entity_quality

            # Cap boost at 50%
            boost_multiplier = 1.0 + min(total_boost * 0.2, 0.5)

            if boost_multiplier > 1.0:
                logger.debug(f"Entity boost for {doc_id}: {boost_multiplier:.2f}x")

            return boost_multiplier
        except Exception as e:
            logger.error(f"Error calculating entity boost: {e}")
            return 1.0
405
+
406
+ # =========================================================================
407
+ # Document Effectiveness
408
+ # =========================================================================
409
+
410
+ def get_doc_effectiveness(self, doc_id: str) -> Optional[Dict[str, Any]]:
411
+ """
412
+ Aggregate success rate for a specific doc from Action KG examples.
413
+
414
+ Returns:
415
+ Dict with success_rate, total_uses, etc., or None if no data
416
+ """
417
+ successes = 0
418
+ failures = 0
419
+ partials = 0
420
+
421
+ for key, stats in self.kg_service.knowledge_graph.get("context_action_effectiveness", {}).items():
422
+ for example in stats.get("examples", []):
423
+ if example.get("doc_id") == doc_id:
424
+ if example["outcome"] == "worked":
425
+ successes += 1
426
+ elif example["outcome"] == "failed":
427
+ failures += 1
428
+ else:
429
+ partials += 1
430
+
431
+ total = successes + failures + partials
432
+ if total == 0:
433
+ return None
434
+
435
+ return {
436
+ "successes": successes,
437
+ "failures": failures,
438
+ "partials": partials,
439
+ "total_uses": total,
440
+ "success_rate": (successes + partials * 0.5) / total
441
+ }
442
+
443
+ # =========================================================================
444
+ # Cross-Encoder Reranking
445
+ # =========================================================================
446
+
447
    async def _rerank_with_cross_encoder(
        self,
        query: str,
        candidates: List[Dict],
        top_k: int
    ) -> List[Dict]:
        """
        Rerank top candidates with cross-encoder for precision.

        Takes the 30 best candidates by final_rank_score, scores each
        (query, text) pair with the cross-encoder, blends that score into
        final_rank_score (memory_bank additionally gets a metadata-quality
        boost), re-sorts, and appends the untouched remainder. On any
        failure the input list is returned unchanged.

        Args:
            query: Original user query.
            candidates: Scored results (final_rank_score expected on each).
            top_k: Caller's page size (not used to truncate here; the full
                merged list is returned).

        Returns:
            Reranked results with cross-encoder scores
        """
        if not self.reranker:
            return candidates

        try:
            # Take top-30 candidates for reranking
            top_candidates = sorted(
                candidates,
                key=lambda x: x.get("final_rank_score", 0.0),
                reverse=True
            )[:30]

            # Prepare query-document pairs; docs truncated to 512 chars to
            # stay within the cross-encoder's input budget.
            pairs = []
            for candidate in top_candidates:
                doc_text = candidate.get("text", "")
                if not doc_text and candidate.get("metadata"):
                    doc_text = candidate.get("metadata", {}).get("content", "")
                pairs.append([query, doc_text[:512]])

            # Score with cross-encoder
            ce_scores = self.reranker.predict(pairs, batch_size=32, show_progress_bar=False)

            # Blend scores
            for i, candidate in enumerate(top_candidates):
                ce_score = float(ce_scores[i])
                candidate["ce_score"] = ce_score
                original_score = candidate.get("final_rank_score", 0.5)

                collection = candidate.get("collection", "")
                if collection == "memory_bank":
                    metadata = candidate.get("metadata", {})
                    importance = self._parse_numeric(metadata.get("importance", 0.7))
                    confidence = self._parse_numeric(metadata.get("confidence", 0.7))
                    quality = importance * confidence

                    # (ce_score + 1) / 2 assumes raw CE scores span roughly
                    # [-1, 1] — TODO confirm against the model in use.
                    # memory_bank: lower CE weight plus up to +30% for quality.
                    ce_norm = (ce_score + 1) / 2
                    ce_weight = 0.3
                    quality_boost = 1.0 + quality * 0.3
                    blended = ((1 - ce_weight) * original_score + ce_weight * ce_norm) * quality_boost
                else:
                    # Other collections: straight 60/40 blend with the CE score.
                    ce_norm = (ce_score + 1) / 2
                    ce_weight = 0.4
                    blended = (1 - ce_weight) * original_score + ce_weight * ce_norm

                candidate["final_rank_score"] = blended

            # Re-sort by updated score
            top_candidates.sort(key=lambda x: x.get("final_rank_score", 0.0), reverse=True)

            # Merge back: use reranked top + remaining
            top_ids = {c.get("id") for c in top_candidates}
            remaining = [c for c in candidates if c.get("id") not in top_ids]

            return top_candidates + remaining

        except Exception as e:
            logger.error(f"Cross-encoder reranking failed: {e}")
            return candidates
517
+
518
+ # =========================================================================
519
+ # Tracking and Caching
520
+ # =========================================================================
521
+
522
    def _track_search_results(
        self,
        query: str,
        results: List[Dict],
        transparency_context: Optional[Any]
    ):
        """
        Track returned results for KG learning and outcome scoring.

        Side effects: updates KG routing patterns via _track_usage, caches
        doc_ids per session in self._cached_doc_ids, and reports the search
        to the transparency context when one is provided.

        Args:
            query: The query that produced these results.
            results: The page of results actually returned to the caller.
            transparency_context: Optional context; session_id and tracking
                hooks are read off it when present.
        """
        # Track usage for returned results
        for result in results:
            if "collection" in result and "id" in result:
                self._track_usage(query, result["collection"], result["id"])

        # Cache doc_ids for outcome scoring
        session_id = 'default'
        if transparency_context and hasattr(transparency_context, 'session_id'):
            session_id = transparency_context.session_id

        # Only the learning collections participate in outcome scoring;
        # books and memory_bank hits are deliberately excluded.
        cached_doc_ids = []
        for result in results:
            collection = result.get("collection", "")
            if collection in ["working", "history", "patterns"]:
                doc_id = result.get("id")
                if doc_id:
                    cached_doc_ids.append(doc_id)

        self._cached_doc_ids[session_id] = cached_doc_ids
        if cached_doc_ids:
            logger.debug(f"Cached {len(cached_doc_ids)} doc_ids for outcome scoring")

        # Track with transparency context
        if transparency_context and hasattr(transparency_context, 'track_memory_search'):
            # Map distance to a (0, 1] confidence via exponential decay.
            confidence_scores = [math.exp(-r.get("distance", 0.5) / 100.0) for r in results]
            # NOTE(review): only the first result's collection is reported
            # (results[:1]) — confirm this is intentional and not a typo.
            transparency_context.track_memory_search(
                query=query,
                collections=[r.get("collection") for r in results[:1]],
                results_count=len(results),
                confidence_scores=confidence_scores
            )
560
+
561
+ def _track_usage(self, query: str, collection: str, doc_id: str):
562
+ """Track which collection was used for which query."""
563
+ concepts = self.kg_service.extract_concepts(query)
564
+ for concept in concepts:
565
+ if concept not in self.kg_service.knowledge_graph["routing_patterns"]:
566
+ self.kg_service.knowledge_graph["routing_patterns"][concept] = {
567
+ "collections_used": {},
568
+ "best_collection": collection,
569
+ "success_rate": 0.5
570
+ }
571
+
572
+ def _get_fragment(self, collection: str, doc_id: str) -> Optional[Dict]:
573
+ """Get a document fragment by collection and ID."""
574
+ if collection in self.collections:
575
+ return self.collections[collection].get_fragment(doc_id)
576
+ return None
577
+
578
+ def get_cached_doc_ids(self, session_id: str = 'default') -> List[str]:
579
+ """Get cached doc_ids for a session."""
580
+ return self._cached_doc_ids.get(session_id, [])
581
+
582
+ # =========================================================================
583
+ # Book Search (Specialized)
584
+ # =========================================================================
585
+
586
+ async def search_books(
587
+ self,
588
+ query: str,
589
+ chunk_type: Optional[str] = None,
590
+ has_code: Optional[bool] = None,
591
+ code_language: Optional[str] = None,
592
+ n_results: int = 5
593
+ ) -> List[Dict]:
594
+ """
595
+ Enhanced book search with metadata filtering.
596
+
597
+ Args:
598
+ query: Search query text
599
+ chunk_type: Filter by chunk type ("code", "prose", "mixed")
600
+ has_code: Filter by presence of code
601
+ code_language: Filter by programming language
602
+ n_results: Number of results to return
603
+
604
+ Returns:
605
+ List of search results with enhanced context
606
+ """
607
+ if "books" not in self.collections:
608
+ logger.warning("Books collection not initialized")
609
+ return []
610
+
611
+ # Build where clause
612
+ where = {}
613
+ if chunk_type:
614
+ where["chunk_type"] = chunk_type
615
+ if has_code is not None:
616
+ where["has_code"] = has_code
617
+ if code_language:
618
+ where["code_language"] = code_language
619
+
620
+ try:
621
+ results = await self.collections["books"].query(
622
+ query_texts=[query],
623
+ n_results=n_results * 2,
624
+ where=where if where else None
625
+ )
626
+
627
+ if not results or not results.get("ids"):
628
+ return []
629
+
630
+ # Format results
631
+ formatted_results = []
632
+ for i in range(min(n_results, len(results["ids"][0]))):
633
+ result = {
634
+ "id": results["ids"][0][i],
635
+ "text": results["documents"][0][i],
636
+ "metadata": results["metadatas"][0][i] if results.get("metadatas") else {},
637
+ "distance": results["distances"][0][i] if results.get("distances") else 0,
638
+ "collection": "books"
639
+ }
640
+ formatted_results.append(result)
641
+
642
+ return formatted_results
643
+
644
+ except Exception as e:
645
+ logger.error(f"Book search failed: {e}")
646
+ return []
@@ -0,0 +1 @@
1
+ # Memory module tests
@@ -0,0 +1,12 @@
1
+ """
2
+ Pytest configuration for memory module tests.
3
+
4
+ Sets up import paths for roampal-core.
5
+ """
6
+
7
+ import sys
8
+ import os
9
+
10
+ # Add roampal-core root to path
11
+ CORE_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..', '..', '..', '..'))
12
+ sys.path.insert(0, CORE_ROOT)
@@ -0,0 +1 @@
1
+ # Unit tests for memory module
@@ -0,0 +1,7 @@
1
+ """
2
+ Pytest configuration for unit tests.
3
+
4
+ Path setup is managed by root conftest.py - no duplicate needed here.
5
+ """
6
+
7
+ # Path setup handled by root conftest.py via pytest_runtest_setup hook