roampal 0.1.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- roampal/__init__.py +29 -0
- roampal/__main__.py +6 -0
- roampal/backend/__init__.py +1 -0
- roampal/backend/modules/__init__.py +1 -0
- roampal/backend/modules/memory/__init__.py +43 -0
- roampal/backend/modules/memory/chromadb_adapter.py +623 -0
- roampal/backend/modules/memory/config.py +102 -0
- roampal/backend/modules/memory/content_graph.py +543 -0
- roampal/backend/modules/memory/context_service.py +455 -0
- roampal/backend/modules/memory/embedding_service.py +96 -0
- roampal/backend/modules/memory/knowledge_graph_service.py +1052 -0
- roampal/backend/modules/memory/memory_bank_service.py +433 -0
- roampal/backend/modules/memory/memory_types.py +296 -0
- roampal/backend/modules/memory/outcome_service.py +400 -0
- roampal/backend/modules/memory/promotion_service.py +473 -0
- roampal/backend/modules/memory/routing_service.py +444 -0
- roampal/backend/modules/memory/scoring_service.py +324 -0
- roampal/backend/modules/memory/search_service.py +646 -0
- roampal/backend/modules/memory/tests/__init__.py +1 -0
- roampal/backend/modules/memory/tests/conftest.py +12 -0
- roampal/backend/modules/memory/tests/unit/__init__.py +1 -0
- roampal/backend/modules/memory/tests/unit/conftest.py +7 -0
- roampal/backend/modules/memory/tests/unit/test_knowledge_graph_service.py +517 -0
- roampal/backend/modules/memory/tests/unit/test_memory_bank_service.py +504 -0
- roampal/backend/modules/memory/tests/unit/test_outcome_service.py +485 -0
- roampal/backend/modules/memory/tests/unit/test_scoring_service.py +255 -0
- roampal/backend/modules/memory/tests/unit/test_search_service.py +413 -0
- roampal/backend/modules/memory/tests/unit/test_unified_memory_system.py +418 -0
- roampal/backend/modules/memory/unified_memory_system.py +1277 -0
- roampal/cli.py +638 -0
- roampal/hooks/__init__.py +16 -0
- roampal/hooks/session_manager.py +587 -0
- roampal/hooks/stop_hook.py +176 -0
- roampal/hooks/user_prompt_submit_hook.py +103 -0
- roampal/mcp/__init__.py +7 -0
- roampal/mcp/server.py +611 -0
- roampal/server/__init__.py +7 -0
- roampal/server/main.py +744 -0
- roampal-0.1.4.dist-info/METADATA +179 -0
- roampal-0.1.4.dist-info/RECORD +44 -0
- roampal-0.1.4.dist-info/WHEEL +5 -0
- roampal-0.1.4.dist-info/entry_points.txt +2 -0
- roampal-0.1.4.dist-info/licenses/LICENSE +190 -0
- roampal-0.1.4.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,255 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Unit Tests for ScoringService
|
|
3
|
+
|
|
4
|
+
Tests the extracted scoring logic to ensure it matches the original behavior.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import sys
|
|
8
|
+
import os
|
|
9
|
+
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..', '..', '..', '..', '..')))
|
|
10
|
+
|
|
11
|
+
import json
|
|
12
|
+
import pytest
|
|
13
|
+
|
|
14
|
+
from roampal.backend.modules.memory.scoring_service import ScoringService, wilson_score_lower
|
|
15
|
+
from roampal.backend.modules.memory.config import MemoryConfig
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class TestWilsonScoreLower:
    """Exercise the standalone Wilson lower-bound scoring function."""

    def test_zero_uses_returns_neutral(self):
        """Zero uses should return 0.5 (neutral)."""
        assert wilson_score_lower(0, 0) == 0.5

    def test_perfect_record_low_samples(self):
        """1/1 should have lower score than proven record."""
        new_score = wilson_score_lower(1, 1)
        proven_score = wilson_score_lower(90, 100)

        # A long proven track record must outrank a single lucky success.
        assert proven_score > new_score
        # The 1/1 lower bound lands near 0.2.
        assert 0.1 < new_score < 0.4
        # The 90/100 lower bound lands near 0.83 (see expected-values test).
        assert 0.8 < proven_score < 0.9

    def test_score_range(self):
        """Score should always be between 0 and 1."""
        for successes, total in (
            (0, 1), (1, 1), (5, 10), (50, 100),
            (99, 100), (100, 100), (0, 100),
        ):
            score = wilson_score_lower(successes, total)
            assert 0 <= score <= 1, f"Score {score} out of range for {successes}/{total}"

    def test_monotonic_increase(self):
        """Higher success rate should give higher score (same sample size)."""
        scores = [wilson_score_lower(i, 10) for i in range(11)]

        # Each extra success at a fixed sample size may never lower the bound.
        for i in range(1, len(scores)):
            assert scores[i] >= scores[i-1], \
                f"Score should increase: {scores[i-1]} -> {scores[i]}"

    def test_sample_size_effect(self):
        """Same success rate with more samples should give higher score."""
        at_10 = wilson_score_lower(5, 10)
        at_100 = wilson_score_lower(50, 100)
        at_1000 = wilson_score_lower(500, 1000)

        # A constant 50% rate tightens toward 0.5 from below as n grows.
        assert at_100 > at_10
        assert at_1000 > at_100

    def test_matches_expected_values(self):
        """Should match expected mathematical values from Wilson formula."""
        # Reference values verified against the original
        # unified_memory_system.py implementation.
        expected = {
            (0, 0): 0.5,
            (1, 1): 0.2065,
            (90, 100): 0.8256,  # Actual output from Wilson formula
            (50, 100): 0.4038,
        }

        for (successes, total), expected_approx in expected.items():
            new_score = wilson_score_lower(successes, total)
            assert abs(new_score - expected_approx) < 0.01, \
                f"Mismatch for {successes}/{total}: got {new_score}, expected ~{expected_approx}"
82
|
+
class TestScoringService:
    """Test the ScoringService class."""

    @pytest.fixture
    def service(self):
        # Default-configured service used by most tests below.
        return ScoringService()

    @pytest.fixture
    def custom_config_service(self):
        # Service with overridden proven-memory weights, to verify the
        # config object is actually honored rather than defaults.
        config = MemoryConfig(
            embedding_weight_proven=0.3,
            learned_weight_proven=0.7,
        )
        return ScoringService(config)

    def test_default_config(self, service):
        """Should use default config if none provided."""
        assert service.config.high_value_threshold == 0.9
        assert service.config.promotion_score_threshold == 0.7

    def test_custom_config(self, custom_config_service):
        """Should use custom config if provided."""
        assert custom_config_service.config.embedding_weight_proven == 0.3
        assert custom_config_service.config.learned_weight_proven == 0.7

    def test_count_successes_empty(self, service):
        """Empty history should return 0."""
        # Both an empty string and an empty JSON array count as no outcomes.
        assert service.count_successes_from_history("") == 0
        assert service.count_successes_from_history("[]") == 0

    def test_count_successes_worked(self, service):
        """Worked outcomes count as 1."""
        history = json.dumps([{"outcome": "worked"}])
        assert service.count_successes_from_history(history) == 1.0

    def test_count_successes_partial(self, service):
        """Partial outcomes count as 0.5."""
        history = json.dumps([{"outcome": "partial"}])
        assert service.count_successes_from_history(history) == 0.5

    def test_count_successes_mixed(self, service):
        """Mixed outcomes should sum correctly."""
        # 1.0 (worked) + 0.5 (partial) + 0 (failed) + 1.0 (worked) = 2.5
        history = json.dumps([
            {"outcome": "worked"},
            {"outcome": "partial"},
            {"outcome": "failed"},
            {"outcome": "worked"},
        ])
        assert service.count_successes_from_history(history) == 2.5

    def test_count_successes_invalid_json(self, service):
        """Invalid JSON should return 0."""
        # Malformed history must degrade to zero rather than raise.
        assert service.count_successes_from_history("invalid") == 0

    def test_calculate_learned_score_no_uses(self, service):
        """No uses should return raw score."""
        learned, wilson = service.calculate_learned_score(0.7, 0)
        assert learned == 0.7
        # Wilson falls back to the neutral 0.5 with zero samples.
        assert wilson == 0.5

    def test_calculate_learned_score_with_history(self, service):
        """Should calculate from outcome history."""
        history = json.dumps([
            {"outcome": "worked"},
            {"outcome": "worked"},
            {"outcome": "partial"},
        ])
        learned, wilson = service.calculate_learned_score(0.5, 3, history)
        # 2 worked + 1 partial = 2.5 successes over 3 uses.
        expected_wilson = wilson_score_lower(2.5, 3)
        assert abs(wilson - expected_wilson) < 0.01

    def test_dynamic_weights_proven(self, service):
        """Proven memories should have high learned weight."""
        emb_w, learn_w = service.get_dynamic_weights(5, 0.85, "history")
        assert emb_w == 0.2
        assert learn_w == 0.8

    def test_dynamic_weights_established(self, service):
        """Established memories should have moderately high learned weight."""
        emb_w, learn_w = service.get_dynamic_weights(3, 0.75, "history")
        assert emb_w == 0.25
        assert learn_w == 0.75

    def test_dynamic_weights_new(self, service):
        """New memories should favor embedding."""
        # With no usage evidence, semantic similarity dominates.
        emb_w, learn_w = service.get_dynamic_weights(0, 0.5, "history")
        assert emb_w > learn_w

    def test_dynamic_weights_failing(self, service):
        """Failing memories should heavily favor embedding."""
        emb_w, learn_w = service.get_dynamic_weights(3, 0.3, "history")
        assert emb_w == 0.7
        assert learn_w == 0.3

    def test_dynamic_weights_memory_bank_high_quality(self, service):
        """High quality memory_bank should use quality-based weights."""
        emb_w, learn_w = service.get_dynamic_weights(
            0, 0.5, "memory_bank", importance=0.95, confidence=0.9
        )
        assert emb_w == 0.45
        assert learn_w == 0.55

    def test_dynamic_weights_memory_bank_normal(self, service):
        """Normal memory_bank should be balanced."""
        emb_w, learn_w = service.get_dynamic_weights(
            0, 0.5, "memory_bank", importance=0.7, confidence=0.7
        )
        assert emb_w == 0.5
        assert learn_w == 0.5

    def test_calculate_final_score_basic(self, service):
        """Should calculate final score correctly."""
        metadata = {"score": 0.7, "uses": 0}
        result = service.calculate_final_score(metadata, distance=1.0, collection="history")

        # The result dict must expose all scoring components.
        assert "final_rank_score" in result
        assert "wilson_score" in result
        assert "embedding_similarity" in result
        assert "learned_score" in result
        assert 0 <= result["final_rank_score"] <= 1

    def test_calculate_final_score_distance_conversion(self, service):
        """Distance should be converted to similarity correctly."""
        metadata = {"score": 0.5, "uses": 0}

        # distance 0 -> perfect similarity (1.0)
        result0 = service.calculate_final_score(metadata, distance=0.0, collection="history")
        assert result0["embedding_similarity"] == 1.0

        # distance 1 -> similarity 0.5 (consistent with 1/(1+d) mapping)
        result1 = service.calculate_final_score(metadata, distance=1.0, collection="history")
        assert result1["embedding_similarity"] == 0.5

        # larger distances keep shrinking similarity monotonically
        result10 = service.calculate_final_score(metadata, distance=10.0, collection="history")
        assert result10["embedding_similarity"] < result1["embedding_similarity"]

    def test_apply_scoring_to_results(self, service):
        """Should apply scoring and sort results."""
        results = [
            {"metadata": {"score": 0.3, "uses": 0}, "distance": 2.0, "collection": "history"},
            {"metadata": {"score": 0.9, "uses": 5}, "distance": 0.5, "collection": "history"},
            {"metadata": {"score": 0.5, "uses": 2}, "distance": 1.0, "collection": "history"},
        ]

        scored = service.apply_scoring_to_results(results)

        # Output must be ordered best-first by final_rank_score.
        scores = [r["final_rank_score"] for r in scored]
        assert scores == sorted(scores, reverse=True)

        # Every entry is annotated with its scoring breakdown.
        for r in scored:
            assert "final_rank_score" in r
            assert "wilson_score" in r
            assert "embedding_weight" in r
|
+
class TestScoringConsistency:
    """Guard that the extracted scoring matches the original behavior."""

    @pytest.fixture
    def service(self):
        # Plain default-configured service.
        return ScoringService()

    def test_memory_bank_quality_scoring(self, service):
        """Memory bank should use importance*confidence as learned score."""
        meta = {"score": 0.5, "uses": 0, "importance": 0.9, "confidence": 0.8}

        scored = service.calculate_final_score(
            meta, distance=0.5, collection="memory_bank"
        )

        # learned_score = importance * confidence = 0.9 * 0.8 = 0.72
        assert abs(scored["learned_score"] - 0.72) < 0.01
|
+
if __name__ == "__main__":
    # pytest.main() RETURNS the exit status instead of exiting; the original
    # discarded it, so running this file directly always exited 0 even when
    # tests failed. Propagate the status so CI and shells see failures.
    raise SystemExit(pytest.main([__file__, "-v"]))
|
|
@@ -0,0 +1,413 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Unit Tests for SearchService
|
|
3
|
+
|
|
4
|
+
Tests the extracted search logic.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import sys
|
|
8
|
+
import os
|
|
9
|
+
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..', '..', '..', '..', '..')))
|
|
10
|
+
|
|
11
|
+
import pytest
|
|
12
|
+
from unittest.mock import MagicMock, AsyncMock, patch
|
|
13
|
+
from datetime import datetime, timedelta
|
|
14
|
+
|
|
15
|
+
from roampal.backend.modules.memory.search_service import SearchService
|
|
16
|
+
from roampal.backend.modules.memory.scoring_service import ScoringService
|
|
17
|
+
from roampal.backend.modules.memory.routing_service import RoutingService
|
|
18
|
+
from roampal.backend.modules.memory.knowledge_graph_service import KnowledgeGraphService
|
|
19
|
+
from roampal.backend.modules.memory.config import MemoryConfig
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class TestSearchServiceInit:
    """Constructor wiring of SearchService."""

    @pytest.fixture
    def mock_dependencies(self):
        """Create mock dependencies."""
        kg = MagicMock(spec=KnowledgeGraphService)
        kg.knowledge_graph = {"routing_patterns": {}, "context_action_effectiveness": {}}

        return {
            "collections": {
                name: MagicMock()
                for name in ("working", "history", "patterns", "books", "memory_bank")
            },
            "scoring_service": MagicMock(spec=ScoringService),
            "routing_service": MagicMock(spec=RoutingService),
            "kg_service": kg,
            "embed_fn": AsyncMock(return_value=[0.1] * 384),
        }

    def test_init_with_all_dependencies(self, mock_dependencies):
        """Should initialize with all dependencies."""
        service = SearchService(**mock_dependencies)

        # Each injected dependency must land on the matching attribute.
        assert service.collections == mock_dependencies["collections"]
        assert service.scoring_service == mock_dependencies["scoring_service"]
        assert service.routing_service == mock_dependencies["routing_service"]
        assert service.kg_service == mock_dependencies["kg_service"]

    def test_init_with_optional_reranker(self, mock_dependencies):
        """Should accept optional reranker."""
        stub_reranker = MagicMock()
        service = SearchService(**mock_dependencies, reranker=stub_reranker)
        assert service.reranker == stub_reranker
|
+
class TestMainSearch:
    """Test main search functionality."""

    @pytest.fixture
    def mock_service(self):
        """Create SearchService with mocks."""
        collections = {
            "working": MagicMock(),
            "history": MagicMock(),
        }

        # Mock hybrid_query to return sample results
        async def mock_hybrid_query(**kwargs):
            return [
                {"id": "doc_1", "text": "test result 1", "distance": 0.5, "metadata": {"score": 0.7, "uses": 3}},
                {"id": "doc_2", "text": "test result 2", "distance": 0.8, "metadata": {"score": 0.5, "uses": 1}},
            ]

        # Every collection serves the same canned hit list.
        for coll in collections.values():
            coll.hybrid_query = AsyncMock(side_effect=mock_hybrid_query)

        # Scoring passes results through unchanged so ordering assertions
        # exercise SearchService itself, not ScoringService.
        scoring = MagicMock(spec=ScoringService)
        scoring.apply_scoring_to_results = MagicMock(side_effect=lambda x: x)

        # Routing always selects both collections and leaves queries as-is.
        routing = MagicMock(spec=RoutingService)
        routing.route_query = MagicMock(return_value=["working", "history"])
        routing.preprocess_query = MagicMock(side_effect=lambda x: x)

        # Knowledge graph contributes nothing: no known solutions, no entities.
        kg = MagicMock(spec=KnowledgeGraphService)
        kg.knowledge_graph = {"routing_patterns": {}, "context_action_effectiveness": {}}
        kg.find_known_solutions = AsyncMock(return_value=[])
        kg.extract_concepts = MagicMock(return_value=["test"])
        kg.content_graph = MagicMock()
        kg.content_graph._doc_entities = {}

        # Fixed 384-dim embedding; the vector content is irrelevant here.
        embed_fn = AsyncMock(return_value=[0.1] * 384)

        return SearchService(
            collections=collections,
            scoring_service=scoring,
            routing_service=routing,
            kg_service=kg,
            embed_fn=embed_fn,
        )

    @pytest.mark.asyncio
    async def test_search_routes_query(self, mock_service):
        """Should route query when collections not specified."""
        await mock_service.search("test query", limit=5)
        mock_service.routing_service.route_query.assert_called_once_with("test query")

    @pytest.mark.asyncio
    async def test_search_uses_explicit_collections(self, mock_service):
        """Should use explicit collections when provided."""
        await mock_service.search("test query", collections=["history"], limit=5)
        # Explicit collections bypass the router entirely.
        mock_service.routing_service.route_query.assert_not_called()

    @pytest.mark.asyncio
    async def test_search_preprocesses_query(self, mock_service):
        """Should preprocess query before embedding."""
        await mock_service.search("test query", limit=5)
        mock_service.routing_service.preprocess_query.assert_called()

    @pytest.mark.asyncio
    async def test_search_generates_embedding(self, mock_service):
        """Should generate embedding for query."""
        await mock_service.search("test query", limit=5)
        mock_service.embed_fn.assert_called()

    @pytest.mark.asyncio
    async def test_search_applies_scoring(self, mock_service):
        """Should apply scoring to results."""
        await mock_service.search("test query", limit=5)
        mock_service.scoring_service.apply_scoring_to_results.assert_called()

    @pytest.mark.asyncio
    async def test_search_returns_list_by_default(self, mock_service):
        """Should return list when return_metadata=False."""
        result = await mock_service.search("test query", limit=5)
        assert isinstance(result, list)

    @pytest.mark.asyncio
    async def test_search_returns_dict_with_metadata(self, mock_service):
        """Should return dict when return_metadata=True."""
        result = await mock_service.search("test query", limit=5, return_metadata=True)
        assert isinstance(result, dict)
        # Metadata envelope carries pagination info alongside the hits.
        assert "results" in result
        assert "total" in result
        assert "has_more" in result

    @pytest.mark.asyncio
    async def test_search_respects_limit(self, mock_service):
        """Should respect limit parameter."""
        result = await mock_service.search("test query", limit=1)
        assert len(result) <= 1

    @pytest.mark.asyncio
    async def test_search_handles_empty_query(self, mock_service):
        """Empty query should return all items."""
        # Mock get for empty query path: with no query text the service
        # falls back to a raw collection.get() instead of a vector query.
        for coll in mock_service.collections.values():
            coll.collection = MagicMock()
            coll.collection.get = MagicMock(return_value={
                'ids': ['id1', 'id2'],
                'documents': ['doc1', 'doc2'],
                'metadatas': [{'score': 0.5}, {'score': 0.6}]
            })

        result = await mock_service.search("", limit=10)
        assert len(result) > 0
|
+
class TestEntityBoost:
    """Entity-overlap boost computation."""

    @pytest.fixture
    def mock_service(self):
        """Create SearchService with entity mocks."""
        kg = MagicMock(spec=KnowledgeGraphService)
        kg.extract_concepts = MagicMock(return_value=["python", "django"])
        kg.knowledge_graph = {"routing_patterns": {}, "context_action_effectiveness": {}}

        # Content graph knows one document and the quality of two entities.
        kg.content_graph = MagicMock()
        kg.content_graph._doc_entities = {"doc_1": {"python", "django", "web"}}
        kg.content_graph.entities = {
            "python": {"avg_quality": 0.9},
            "django": {"avg_quality": 0.8},
        }

        return SearchService(
            collections={},
            scoring_service=MagicMock(),
            routing_service=MagicMock(),
            kg_service=kg,
            embed_fn=AsyncMock(),
        )

    def test_entity_boost_with_matches(self, mock_service):
        """Should boost documents with matching high-quality entities."""
        boost = mock_service._calculate_entity_boost("python django", "doc_1")

        # python (0.9) + django (0.8) = 1.7 quality;
        # boost = 1.0 + min(1.7 * 0.2, 0.5) = 1.34, capped at 1.5.
        assert boost > 1.0
        assert boost <= 1.5

    def test_entity_boost_no_matches(self, mock_service):
        """Should return 1.0 when no entity matches."""
        assert mock_service._calculate_entity_boost("python django", "doc_unknown") == 1.0

    def test_entity_boost_empty_query(self, mock_service):
        """Should return 1.0 for empty query."""
        mock_service.kg_service.extract_concepts = MagicMock(return_value=[])
        assert mock_service._calculate_entity_boost("", "doc_1") == 1.0
|
+
class TestDocEffectiveness:
    """Per-document effectiveness derived from recorded outcomes."""

    @pytest.fixture
    def mock_service(self):
        """Create SearchService with effectiveness data."""
        canned_examples = [
            {"doc_id": "doc_1", "outcome": "worked"},
            {"doc_id": "doc_1", "outcome": "worked"},
            {"doc_id": "doc_1", "outcome": "failed"},
            {"doc_id": "doc_2", "outcome": "partial"},
        ]

        kg = MagicMock(spec=KnowledgeGraphService)
        kg.knowledge_graph = {
            "routing_patterns": {},
            "context_action_effectiveness": {
                "context|action|coll": {"examples": canned_examples},
            },
        }

        return SearchService(
            collections={},
            scoring_service=MagicMock(),
            routing_service=MagicMock(),
            kg_service=kg,
            embed_fn=AsyncMock(),
        )

    def test_doc_effectiveness_calculates_rate(self, mock_service):
        """Should calculate success rate correctly."""
        eff = mock_service.get_doc_effectiveness("doc_1")

        assert eff is not None
        # doc_1 recorded 2 worked + 1 failed.
        assert eff["successes"] == 2
        assert eff["failures"] == 1
        assert eff["total_uses"] == 3
        # success_rate = (2 + 0) / 3 = 0.667
        assert abs(eff["success_rate"] - 0.667) < 0.01

    def test_doc_effectiveness_unknown_doc(self, mock_service):
        """Should return None for unknown document."""
        assert mock_service.get_doc_effectiveness("unknown_doc") is None

    def test_doc_effectiveness_partial_counts(self, mock_service):
        """Should count partial as 0.5 success."""
        eff = mock_service.get_doc_effectiveness("doc_2")

        assert eff is not None
        assert eff["partials"] == 1
        # success_rate = (0 + 1*0.5) / 1 = 0.5
        assert eff["success_rate"] == 0.5
|
+
class TestCollectionBoosts:
    """Collection-specific distance adjustments."""

    @pytest.fixture
    def service(self):
        """SearchService with an empty mocked knowledge graph."""
        kg = MagicMock(spec=KnowledgeGraphService)
        kg.extract_concepts = MagicMock(return_value=[])
        kg.content_graph = MagicMock()
        kg.content_graph._doc_entities = {}
        kg.content_graph.entities = {}
        kg.knowledge_graph = {"routing_patterns": {}, "context_action_effectiveness": {}}

        return SearchService(
            collections={},
            scoring_service=MagicMock(),
            routing_service=MagicMock(),
            kg_service=kg,
            embed_fn=AsyncMock(),
        )

    def test_patterns_boost(self, service):
        """Patterns should get 10% distance reduction."""
        hit = {"distance": 1.0, "metadata": {}}
        service._apply_collection_boost(hit, "patterns", "query")
        # Boost mutates the hit in place: 1.0 -> 0.9.
        assert hit["distance"] == 0.9

    def test_memory_bank_quality_boost(self, service):
        """Memory bank should boost by quality score."""
        hit = {
            "distance": 1.0,
            "id": "doc_1",
            "metadata": {"importance": 0.9, "confidence": 0.9},
        }
        service._apply_collection_boost(hit, "memory_bank", "query")
        # quality = 0.81, metadata_boost = 1.0 - 0.81*0.8 = 0.352
        assert hit["distance"] < 1.0

    def test_books_recent_upload_boost(self, service):
        """Recent books should get boost."""
        # NOTE(review): datetime.utcnow() is deprecated since Python 3.12;
        # kept to match the naive-UTC timestamp format the service appears
        # to compare against — confirm before migrating to aware datetimes.
        hit = {
            "distance": 1.0,
            "upload_timestamp": datetime.utcnow().isoformat(),
            "metadata": {},
        }
        service._apply_collection_boost(hit, "books", "query")
        assert hit["distance"] == 0.7
|
+
class TestCaching:
    """doc_id caching that feeds later outcome scoring."""

    @pytest.fixture
    def service(self):
        """SearchService with a minimal mocked knowledge graph."""
        kg = MagicMock(spec=KnowledgeGraphService)
        kg.extract_concepts = MagicMock(return_value=["test"])
        kg.knowledge_graph = {"routing_patterns": {}}

        return SearchService(
            collections={},
            scoring_service=MagicMock(),
            routing_service=MagicMock(),
            kg_service=kg,
            embed_fn=AsyncMock(),
        )

    def test_track_search_caches_doc_ids(self, service):
        """Should cache doc_ids from scoreable collections."""
        hits = [
            {"id": "working_1", "collection": "working"},
            {"id": "history_1", "collection": "history"},
            {"id": "books_1", "collection": "books"},  # books is not scoreable
        ]

        # No context -> results land in the 'default' session bucket.
        service._track_search_results("query", hits, None)
        cached = service.get_cached_doc_ids('default')

        assert "working_1" in cached
        assert "history_1" in cached
        assert "books_1" not in cached

    def test_caching_per_session(self, service):
        """Should cache separately per session."""
        for session, query, doc_id in (
            ("session_1", "q1", "doc_1"),
            ("session_2", "q2", "doc_2"),
        ):
            ctx = MagicMock()
            ctx.session_id = session
            service._track_search_results(
                query, [{"id": doc_id, "collection": "working"}], ctx
            )

        # Each session only sees its own documents.
        assert service.get_cached_doc_ids("session_1") == ["doc_1"]
        assert service.get_cached_doc_ids("session_2") == ["doc_2"]
|
+
class TestParseNumeric:
    """Test numeric parsing helper."""

    @pytest.fixture
    def service(self):
        # Bare service; _parse_numeric needs no real dependencies.
        return SearchService(
            collections={},
            scoring_service=MagicMock(),
            routing_service=MagicMock(),
            kg_service=MagicMock(),
            embed_fn=AsyncMock(),
        )

    def test_parse_float(self, service):
        # Floats pass through unchanged.
        assert service._parse_numeric(0.9) == 0.9

    def test_parse_int(self, service):
        # Ints are coerced to float.
        assert service._parse_numeric(1) == 1.0

    def test_parse_list(self, service):
        # Lists collapse to their first element.
        assert service._parse_numeric([0.8, 0.9]) == 0.8

    def test_parse_string_high(self, service):
        # Qualitative labels map to fixed numeric levels.
        assert service._parse_numeric("high") == 0.9

    def test_parse_string_medium(self, service):
        assert service._parse_numeric("medium") == 0.7

    def test_parse_string_low(self, service):
        assert service._parse_numeric("low") == 0.5

    def test_parse_none(self, service):
        # Missing values fall back to the 0.7 default.
        assert service._parse_numeric(None) == 0.7

    def test_parse_invalid(self, service):
        # Unrecognized strings also get the default rather than raising.
        assert service._parse_numeric("invalid") == 0.7
|
+
if __name__ == "__main__":
    # pytest.main() RETURNS the exit status instead of exiting; the original
    # discarded it, so running this file directly always exited 0 even when
    # tests failed. Propagate the status so CI and shells see failures.
    raise SystemExit(pytest.main([__file__, "-v"]))
|