alma-memory 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- alma/__init__.py +75 -0
- alma/config/__init__.py +5 -0
- alma/config/loader.py +156 -0
- alma/core.py +322 -0
- alma/harness/__init__.py +35 -0
- alma/harness/base.py +377 -0
- alma/harness/domains.py +689 -0
- alma/integration/__init__.py +62 -0
- alma/integration/claude_agents.py +432 -0
- alma/integration/helena.py +413 -0
- alma/integration/victor.py +447 -0
- alma/learning/__init__.py +86 -0
- alma/learning/forgetting.py +1396 -0
- alma/learning/heuristic_extractor.py +374 -0
- alma/learning/protocols.py +326 -0
- alma/learning/validation.py +341 -0
- alma/mcp/__init__.py +45 -0
- alma/mcp/__main__.py +155 -0
- alma/mcp/resources.py +121 -0
- alma/mcp/server.py +533 -0
- alma/mcp/tools.py +374 -0
- alma/retrieval/__init__.py +53 -0
- alma/retrieval/cache.py +1062 -0
- alma/retrieval/embeddings.py +202 -0
- alma/retrieval/engine.py +287 -0
- alma/retrieval/scoring.py +334 -0
- alma/storage/__init__.py +20 -0
- alma/storage/azure_cosmos.py +972 -0
- alma/storage/base.py +372 -0
- alma/storage/file_based.py +583 -0
- alma/storage/sqlite_local.py +912 -0
- alma/types.py +216 -0
- alma_memory-0.2.0.dist-info/METADATA +327 -0
- alma_memory-0.2.0.dist-info/RECORD +36 -0
- alma_memory-0.2.0.dist-info/WHEEL +5 -0
- alma_memory-0.2.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,202 @@
|
|
|
1
|
+
"""
|
|
2
|
+
ALMA Embedding Providers.
|
|
3
|
+
|
|
4
|
+
Supports local (sentence-transformers) and Azure OpenAI embeddings.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import logging
|
|
8
|
+
from abc import ABC, abstractmethod
|
|
9
|
+
from typing import List, Optional
|
|
10
|
+
|
|
11
|
+
logger = logging.getLogger(__name__)
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class EmbeddingProvider(ABC):
    """
    Common interface for every embedding backend.

    A concrete provider must supply ``encode``, ``encode_batch`` and the
    ``dimension`` property; the class cannot be instantiated until all
    three are overridden.
    """

    @abstractmethod
    def encode(self, text: str) -> List[float]:
        """
        Embed a single piece of text.

        Args:
            text: Text to embed

        Returns:
            The embedding as a list of floats
        """

    @abstractmethod
    def encode_batch(self, texts: List[str]) -> List[List[float]]:
        """
        Embed several texts in one call.

        Args:
            texts: Texts to embed

        Returns:
            One embedding (list of floats) per input text
        """

    @property
    @abstractmethod
    def dimension(self) -> int:
        """Length of the vectors produced by this provider."""
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class LocalEmbedder(EmbeddingProvider):
    """
    Local embeddings using sentence-transformers.

    Default model: all-MiniLM-L6-v2 (384 dimensions, fast, good quality)

    The model is not loaded in the constructor; it is loaded on the first
    call that needs it (``encode``, a non-empty ``encode_batch``, or
    ``dimension``).
    """

    def __init__(self, model_name: str = "all-MiniLM-L6-v2"):
        """
        Initialize local embedder.

        Args:
            model_name: Sentence-transformers model name
        """
        self.model_name = model_name
        self._model = None
        self._dimension: Optional[int] = None

    def _load_model(self):
        """
        Lazy load the model.

        Does nothing when a model is already loaded.

        Raises:
            ImportError: If the sentence-transformers package cannot be
                imported. The original import failure is attached as the
                cause.
        """
        if self._model is not None:
            return

        # Only the import is guarded, so an ImportError raised while the
        # model itself is being constructed is not misreported as
        # "package missing".
        try:
            from sentence_transformers import SentenceTransformer
        except ImportError as exc:
            raise ImportError(
                "sentence-transformers is required for local embeddings. "
                "Install with: pip install sentence-transformers"
            ) from exc

        logger.info(f"Loading embedding model: {self.model_name}")
        self._model = SentenceTransformer(self.model_name)
        self._dimension = self._model.get_sentence_embedding_dimension()
        logger.info(f"Model loaded, dimension: {self._dimension}")

    def encode(self, text: str) -> List[float]:
        """
        Generate embedding for text.

        Args:
            text: Text to embed

        Returns:
            The embedding converted to a plain list; the model is asked
            for normalized embeddings (``normalize_embeddings=True``)
        """
        self._load_model()
        embedding = self._model.encode(text, normalize_embeddings=True)
        return embedding.tolist()

    def encode_batch(self, texts: List[str]) -> List[List[float]]:
        """
        Generate embeddings for multiple texts.

        Args:
            texts: Texts to embed

        Returns:
            One embedding per input text, each converted to a plain list.
            An empty input returns an empty list without loading the model.
        """
        if not texts:
            # Nothing to embed: avoid loading the model for no output.
            return []
        self._load_model()
        embeddings = self._model.encode(texts, normalize_embeddings=True)
        return [emb.tolist() for emb in embeddings]

    @property
    def dimension(self) -> int:
        """
        Return embedding dimension.

        Loads the model if the dimension is not known yet. Falls back to
        384 when the model reports no dimension.
        """
        if self._dimension is None:
            self._load_model()
        return self._dimension or 384  # Default for all-MiniLM-L6-v2
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
class AzureEmbedder(EmbeddingProvider):
    """
    Azure OpenAI embeddings.

    Uses text-embedding-3-small by default (1536 dimensions).

    The reported ``dimension`` is fixed at 1536 and is not derived from the
    configured deployment.
    """

    def __init__(
        self,
        endpoint: Optional[str] = None,
        api_key: Optional[str] = None,
        deployment: str = "text-embedding-3-small",
        api_version: str = "2024-02-01",
    ):
        """
        Initialize Azure OpenAI embedder.

        Args:
            endpoint: Azure OpenAI endpoint (or use AZURE_OPENAI_ENDPOINT env var)
            api_key: Azure OpenAI API key (or use AZURE_OPENAI_KEY env var)
            deployment: Deployment name for embedding model
            api_version: API version

        Raises:
            ValueError: If no endpoint is given and AZURE_OPENAI_ENDPOINT
                is not set. A missing API key is not checked here.
        """
        import os

        self.endpoint = endpoint or os.environ.get("AZURE_OPENAI_ENDPOINT")
        self.api_key = api_key or os.environ.get("AZURE_OPENAI_KEY")
        self.deployment = deployment
        self.api_version = api_version
        self._client = None
        self._dimension = 1536  # Default for text-embedding-3-small

        if not self.endpoint:
            raise ValueError(
                "Azure OpenAI endpoint required. Set AZURE_OPENAI_ENDPOINT env var "
                "or pass endpoint parameter."
            )

    def _get_client(self):
        """
        Get or create Azure OpenAI client.

        The client is created on first use and reused afterwards.

        Raises:
            ImportError: If the openai package cannot be imported. The
                original import failure is attached as the cause.
        """
        if self._client is None:
            # Only the import is guarded, so errors raised while building
            # the client are not misreported as "package missing".
            try:
                from openai import AzureOpenAI
            except ImportError as exc:
                raise ImportError(
                    "openai is required for Azure embeddings. "
                    "Install with: pip install openai"
                ) from exc

            self._client = AzureOpenAI(
                azure_endpoint=self.endpoint,
                api_key=self.api_key,
                api_version=self.api_version,
            )
        return self._client

    def encode(self, text: str) -> List[float]:
        """
        Generate embedding for text.

        Args:
            text: Text to embed

        Returns:
            The embedding of the first item in the service response
        """
        client = self._get_client()
        response = client.embeddings.create(
            input=text,
            model=self.deployment,
        )
        return response.data[0].embedding

    def encode_batch(self, texts: List[str]) -> List[List[float]]:
        """
        Generate embeddings for multiple texts.

        Args:
            texts: Texts to embed

        Returns:
            One embedding per input text, in input order. An empty input
            returns an empty list without contacting the service.
        """
        if not texts:
            # An empty input array is not a valid embeddings request, and
            # there is nothing to compute anyway.
            return []
        client = self._get_client()
        response = client.embeddings.create(
            input=texts,
            model=self.deployment,
        )
        # Sort by index to ensure order matches input
        sorted_data = sorted(response.data, key=lambda x: x.index)
        return [item.embedding for item in sorted_data]

    @property
    def dimension(self) -> int:
        """Return embedding dimension (the fixed value set in ``__init__``)."""
        return self._dimension
|
|
168
|
+
|
|
169
|
+
|
|
170
|
+
class MockEmbedder(EmbeddingProvider):
    """
    Test double for an embedding provider.

    Produces reproducible pseudo-embeddings derived from the SHA-256 digest
    of the input text, so equal texts always map to equal vectors.
    """

    def __init__(self, dimension: int = 384):
        """
        Create a mock embedder.

        Args:
            dimension: Length of the vectors to generate
        """
        self._dimension = dimension

    def encode(self, text: str) -> List[float]:
        """
        Build a deterministic fake embedding for ``text``.

        Each component comes from one digest byte mapped onto [-1, 1]; the
        digest is reused cyclically when the requested dimension is larger
        than the digest length.
        """
        import hashlib

        digest = hashlib.sha256(text.encode()).digest()
        span = len(digest)
        # byte / 127.5 - 1.0 maps 0..255 onto -1.0..1.0
        return [
            (digest[position % span] / 127.5) - 1.0
            for position in range(self._dimension)
        ]

    def encode_batch(self, texts: List[str]) -> List[List[float]]:
        """Build one deterministic fake embedding per entry of ``texts``."""
        return list(map(self.encode, texts))

    @property
    def dimension(self) -> int:
        """Length of the vectors this mock generates."""
        return self._dimension
|
alma/retrieval/engine.py
ADDED
|
@@ -0,0 +1,287 @@
|
|
|
1
|
+
"""
|
|
2
|
+
ALMA Retrieval Engine.
|
|
3
|
+
|
|
4
|
+
Handles semantic search and memory retrieval with scoring and caching.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import time
|
|
8
|
+
import logging
|
|
9
|
+
from typing import Optional, List, Dict, Any
|
|
10
|
+
|
|
11
|
+
from alma.types import MemorySlice, MemoryScope
|
|
12
|
+
from alma.storage.base import StorageBackend
|
|
13
|
+
from alma.retrieval.scoring import MemoryScorer, ScoringWeights, ScoredItem
|
|
14
|
+
from alma.retrieval.cache import RetrievalCache, NullCache
|
|
15
|
+
|
|
16
|
+
logger = logging.getLogger(__name__)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class RetrievalEngine:
    """
    Retrieves relevant memories for task context injection.

    Features:
    - Semantic search via embeddings
    - Recency weighting (newer memories preferred)
    - Success rate weighting (proven strategies ranked higher)
    - Caching for repeated queries
    - Configurable scoring weights
    """

    def __init__(
        self,
        storage: StorageBackend,
        embedding_provider: str = "local",
        cache_ttl_seconds: int = 300,
        enable_cache: bool = True,
        max_cache_entries: int = 1000,
        scoring_weights: Optional[ScoringWeights] = None,
        recency_half_life_days: float = 30.0,
        min_score_threshold: float = 0.2,
    ):
        """
        Initialize retrieval engine.

        Args:
            storage: Storage backend to query
            embedding_provider: "local" (sentence-transformers), "azure"
                (Azure OpenAI) or "mock"; any other value falls back to
                the local provider
            cache_ttl_seconds: How long to cache query results
            enable_cache: Whether to enable caching
            max_cache_entries: Maximum cache entries before eviction
            scoring_weights: Custom weights for similarity/recency/success/confidence
            recency_half_life_days: Days after which recency score halves
            min_score_threshold: Minimum score to include in results
        """
        self.storage = storage
        self.embedding_provider = embedding_provider
        self.min_score_threshold = min_score_threshold
        # The embedder is created lazily on the first embedding request.
        self._embedder = None

        # Initialize scorer
        self.scorer = MemoryScorer(
            weights=scoring_weights or ScoringWeights(),
            recency_half_life_days=recency_half_life_days,
        )

        # Initialize cache
        if enable_cache:
            self.cache = RetrievalCache(
                ttl_seconds=cache_ttl_seconds,
                max_entries=max_cache_entries,
            )
        else:
            self.cache = NullCache()

    @staticmethod
    def _elapsed_ms(start: float) -> int:
        """Return whole milliseconds elapsed since ``start`` (a ``time.perf_counter()`` reading)."""
        return int((time.perf_counter() - start) * 1000)

    def retrieve(
        self,
        query: str,
        agent: str,
        project_id: str,
        user_id: Optional[str] = None,
        top_k: int = 5,
        scope: Optional[MemoryScope] = None,
        bypass_cache: bool = False,
    ) -> MemorySlice:
        """
        Retrieve relevant memories for a task.

        Args:
            query: Task description to find relevant memories for
            agent: Agent requesting memories
            project_id: Project context
            user_id: Optional user for preference retrieval
            top_k: Max items per memory type
            scope: Agent's learning scope for filtering. Accepted for
                interface compatibility; this method does not currently
                read it, and it is not part of the cache lookup.
            bypass_cache: Skip cache lookup/storage

        Returns:
            MemorySlice with relevant memories, scored and ranked. On a
            cache hit the object returned by the cache is handed back with
            its ``retrieval_time_ms`` overwritten for this call.
        """
        # perf_counter is monotonic, so the measured duration cannot be
        # skewed by system clock adjustments.
        start_time = time.perf_counter()

        # Check cache first
        if not bypass_cache:
            cached = self.cache.get(query, agent, project_id, user_id, top_k)
            if cached is not None:
                cached.retrieval_time_ms = self._elapsed_ms(start_time)
                logger.debug(f"Cache hit for query: {query[:50]}...")
                return cached

        # Generate embedding for query
        query_embedding = self._get_embedding(query)

        # Retrieve raw items from storage (with vector search)
        raw_heuristics = self.storage.get_heuristics(
            project_id=project_id,
            agent=agent,
            embedding=query_embedding,
            top_k=top_k * 2,  # Get extra for scoring/filtering
            min_confidence=0.0,  # Let scorer handle filtering
        )

        raw_outcomes = self.storage.get_outcomes(
            project_id=project_id,
            agent=agent,
            embedding=query_embedding,
            top_k=top_k * 2,
            success_only=False,
        )

        raw_domain_knowledge = self.storage.get_domain_knowledge(
            project_id=project_id,
            agent=agent,
            embedding=query_embedding,
            top_k=top_k * 2,
        )

        raw_anti_patterns = self.storage.get_anti_patterns(
            project_id=project_id,
            agent=agent,
            embedding=query_embedding,
            top_k=top_k * 2,
        )

        # Score and rank each type
        scored_heuristics = self.scorer.score_heuristics(raw_heuristics)
        scored_outcomes = self.scorer.score_outcomes(raw_outcomes)
        scored_knowledge = self.scorer.score_domain_knowledge(raw_domain_knowledge)
        scored_anti_patterns = self.scorer.score_anti_patterns(raw_anti_patterns)

        # Apply threshold and limit
        final_heuristics = self._extract_top_k(scored_heuristics, top_k)
        final_outcomes = self._extract_top_k(scored_outcomes, top_k)
        final_knowledge = self._extract_top_k(scored_knowledge, top_k)
        final_anti_patterns = self._extract_top_k(scored_anti_patterns, top_k)

        # Get user preferences (not scored, just retrieved)
        preferences = []
        if user_id:
            preferences = self.storage.get_user_preferences(user_id=user_id)

        retrieval_time_ms = self._elapsed_ms(start_time)

        result = MemorySlice(
            heuristics=final_heuristics,
            outcomes=final_outcomes,
            preferences=preferences,
            domain_knowledge=final_knowledge,
            anti_patterns=final_anti_patterns,
            query=query,
            agent=agent,
            retrieval_time_ms=retrieval_time_ms,
        )

        # Cache result
        if not bypass_cache:
            self.cache.set(query, agent, project_id, result, user_id, top_k)

        logger.info(
            f"Retrieved {result.total_items} memories for '{query[:50]}...' "
            f"in {retrieval_time_ms}ms"
        )

        return result

    def _extract_top_k(
        self,
        scored_items: List[ScoredItem],
        top_k: int,
    ) -> List[Any]:
        """
        Extract top-k items after filtering by score threshold.

        Args:
            scored_items: Scored and sorted items
            top_k: Maximum number to return; zero or a negative value
                yields an empty list

        Returns:
            List of original items (unwrapped from ScoredItem)
        """
        if top_k <= 0:
            # A negative bound would otherwise slice from the end and
            # return all but the last items instead of none.
            return []

        filtered = self.scorer.apply_score_threshold(
            scored_items, self.min_score_threshold
        )
        return [item.item for item in filtered[:top_k]]

    def _get_embedding(self, text: str) -> List[float]:
        """
        Generate embedding for text.

        Uses lazy initialization of embedding model.
        """
        if self._embedder is None:
            self._embedder = self._init_embedder()

        return self._embedder.encode(text)

    def _init_embedder(self):
        """
        Initialize the embedding model based on provider config.

        "azure" and "mock" select their respective embedders; every other
        value selects the local embedder. A value other than "local" that
        reaches the fallback is logged as a warning, since it usually
        indicates a misconfiguration.
        """
        if self.embedding_provider == "azure":
            from alma.retrieval.embeddings import AzureEmbedder
            return AzureEmbedder()
        elif self.embedding_provider == "mock":
            from alma.retrieval.embeddings import MockEmbedder
            return MockEmbedder()
        else:
            if self.embedding_provider != "local":
                logger.warning(
                    "Unknown embedding provider %r; falling back to local embeddings",
                    self.embedding_provider,
                )
            from alma.retrieval.embeddings import LocalEmbedder
            return LocalEmbedder()

    def invalidate_cache(
        self,
        agent: Optional[str] = None,
        project_id: Optional[str] = None,
    ):
        """
        Invalidate cache entries.

        Should be called after memory updates to ensure fresh results.

        Args:
            agent: Invalidate entries for this agent
            project_id: Invalidate entries for this project
        """
        self.cache.invalidate(agent=agent, project_id=project_id)

    def get_cache_stats(self) -> Dict[str, Any]:
        """Get cache performance statistics as a dictionary."""
        stats = self.cache.get_stats()
        return stats.to_dict()

    def clear_cache(self):
        """Clear all cached results."""
        self.cache.clear()

    def get_scorer_weights(self) -> Dict[str, float]:
        """
        Get current scoring weights.

        Returns:
            Mapping with the keys "similarity", "recency", "success_rate"
            and "confidence"
        """
        w = self.scorer.weights
        return {
            "similarity": w.similarity,
            "recency": w.recency,
            "success_rate": w.success_rate,
            "confidence": w.confidence,
        }

    def update_scorer_weights(
        self,
        similarity: Optional[float] = None,
        recency: Optional[float] = None,
        success_rate: Optional[float] = None,
        confidence: Optional[float] = None,
    ):
        """
        Update scoring weights.

        Arguments left as None keep their current value. A new
        ScoringWeights object replaces the scorer's weights; any
        normalization (the weights are documented as summing to 1.0) is
        left to ScoringWeights and is not performed here. The cache is
        cleared afterwards because cached results were ranked with the old
        weights.

        Args:
            similarity: Weight for semantic similarity
            recency: Weight for recency
            success_rate: Weight for success rate
            confidence: Weight for stored confidence
        """
        current = self.scorer.weights
        self.scorer.weights = ScoringWeights(
            similarity=similarity if similarity is not None else current.similarity,
            recency=recency if recency is not None else current.recency,
            success_rate=success_rate if success_rate is not None else current.success_rate,
            confidence=confidence if confidence is not None else current.confidence,
        )
        # Clear cache since scoring changed
        self.cache.clear()
|