alma-memory 0.2.0 (alma_memory-0.2.0-py3-none-any.whl)

This diff shows the content of publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,202 @@
+ """
+ ALMA Embedding Providers.
+
+ Supports local (sentence-transformers) and Azure OpenAI embeddings.
+ """
+
+ import logging
+ from abc import ABC, abstractmethod
+ from typing import List, Optional
+
+ logger = logging.getLogger(__name__)
+
+
+ class EmbeddingProvider(ABC):
+     """Abstract base class for embedding providers."""
+
+     @abstractmethod
+     def encode(self, text: str) -> List[float]:
+         """Generate embedding for text."""
+         pass
+
+     @abstractmethod
+     def encode_batch(self, texts: List[str]) -> List[List[float]]:
+         """Generate embeddings for multiple texts."""
+         pass
+
+     @property
+     @abstractmethod
+     def dimension(self) -> int:
+         """Return embedding dimension."""
+         pass
+
+
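+ # --- Editor's note: illustrative sketch, not part of the original module ---
+ # Any concrete provider only has to implement the three members above, so
+ # calling code can stay provider-agnostic. A hypothetical helper:
+ #
+ #     def embed_corpus(provider: EmbeddingProvider, docs: List[str]) -> List[List[float]]:
+ #         return provider.encode_batch(docs)
+
+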
+ class LocalEmbedder(EmbeddingProvider):
+     """
+     Local embeddings using sentence-transformers.
+
+     Default model: all-MiniLM-L6-v2 (384 dimensions, fast, good quality).
+     """
+
+     def __init__(self, model_name: str = "all-MiniLM-L6-v2"):
+         """
+         Initialize local embedder.
+
+         Args:
+             model_name: Sentence-transformers model name
+         """
+         self.model_name = model_name
+         self._model = None
+         self._dimension: Optional[int] = None
+
+     def _load_model(self):
+         """Lazily load the model on first use."""
+         if self._model is None:
+             try:
+                 from sentence_transformers import SentenceTransformer
+
+                 logger.info(f"Loading embedding model: {self.model_name}")
+                 self._model = SentenceTransformer(self.model_name)
+                 self._dimension = self._model.get_sentence_embedding_dimension()
+                 logger.info(f"Model loaded, dimension: {self._dimension}")
+             except ImportError:
+                 raise ImportError(
+                     "sentence-transformers is required for local embeddings. "
+                     "Install with: pip install sentence-transformers"
+                 )
+
+     def encode(self, text: str) -> List[float]:
+         """Generate embedding for text."""
+         self._load_model()
+         embedding = self._model.encode(text, normalize_embeddings=True)
+         return embedding.tolist()
+
+     def encode_batch(self, texts: List[str]) -> List[List[float]]:
+         """Generate embeddings for multiple texts."""
+         self._load_model()
+         embeddings = self._model.encode(texts, normalize_embeddings=True)
+         return [emb.tolist() for emb in embeddings]
+
+     @property
+     def dimension(self) -> int:
+         """Return embedding dimension."""
+         if self._dimension is None:
+             self._load_model()
+         return self._dimension or 384  # Fallback for all-MiniLM-L6-v2
+
+
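+ # --- Editor's note: illustrative usage, not part of the original module ---
+ # encode() returns L2-normalized vectors, so cosine similarity reduces to a
+ # plain dot product:
+ #
+ #     embedder = LocalEmbedder()
+ #     a = embedder.encode("reset a user's password")
+ #     b = embedder.encode("how do I change my password?")
+ #     similarity = sum(x * y for x, y in zip(a, b))  # cosine, in [-1, 1]
+
+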
+ class AzureEmbedder(EmbeddingProvider):
+     """
+     Azure OpenAI embeddings.
+
+     Uses text-embedding-3-small by default (1536 dimensions).
+     """
+
+     def __init__(
+         self,
+         endpoint: Optional[str] = None,
+         api_key: Optional[str] = None,
+         deployment: str = "text-embedding-3-small",
+         api_version: str = "2024-02-01",
+     ):
+         """
+         Initialize Azure OpenAI embedder.
+
+         Args:
+             endpoint: Azure OpenAI endpoint (or use AZURE_OPENAI_ENDPOINT env var)
+             api_key: Azure OpenAI API key (or use AZURE_OPENAI_KEY env var)
+             deployment: Deployment name for the embedding model
+             api_version: API version
+         """
+         import os
+
+         self.endpoint = endpoint or os.environ.get("AZURE_OPENAI_ENDPOINT")
+         self.api_key = api_key or os.environ.get("AZURE_OPENAI_KEY")
+         self.deployment = deployment
+         self.api_version = api_version
+         self._client = None
+         self._dimension = 1536  # Default for text-embedding-3-small
+
+         if not self.endpoint:
+             raise ValueError(
+                 "Azure OpenAI endpoint required. Set AZURE_OPENAI_ENDPOINT env var "
+                 "or pass the endpoint parameter."
+             )
+         if not self.api_key:
+             raise ValueError(
+                 "Azure OpenAI API key required. Set AZURE_OPENAI_KEY env var "
+                 "or pass the api_key parameter."
+             )
+
+     def _get_client(self):
+         """Get or create the Azure OpenAI client."""
+         if self._client is None:
+             try:
+                 from openai import AzureOpenAI
+
+                 self._client = AzureOpenAI(
+                     azure_endpoint=self.endpoint,
+                     api_key=self.api_key,
+                     api_version=self.api_version,
+                 )
+             except ImportError:
+                 raise ImportError(
+                     "openai is required for Azure embeddings. "
+                     "Install with: pip install openai"
+                 )
+         return self._client
+
+     def encode(self, text: str) -> List[float]:
+         """Generate embedding for text."""
+         client = self._get_client()
+         response = client.embeddings.create(
+             input=text,
+             model=self.deployment,
+         )
+         return response.data[0].embedding
+
+     def encode_batch(self, texts: List[str]) -> List[List[float]]:
+         """Generate embeddings for multiple texts."""
+         client = self._get_client()
+         response = client.embeddings.create(
+             input=texts,
+             model=self.deployment,
+         )
+         # Sort by index so output order matches input order
+         sorted_data = sorted(response.data, key=lambda x: x.index)
+         return [item.embedding for item in sorted_data]
+
+     @property
+     def dimension(self) -> int:
+         """Return embedding dimension."""
+         return self._dimension
+
+
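+ # --- Editor's note: illustrative usage, not part of the original module ---
+ # Assuming AZURE_OPENAI_ENDPOINT and AZURE_OPENAI_KEY are set in the
+ # environment and a "text-embedding-3-small" deployment exists:
+ #
+ #     embedder = AzureEmbedder()
+ #     vectors = embedder.encode_batch(["first doc", "second doc"])
+ #     assert len(vectors) == 2 and len(vectors[0]) == embedder.dimension
+
+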
+ class MockEmbedder(EmbeddingProvider):
+     """
+     Mock embedder for testing.
+
+     Generates deterministic fake embeddings based on text hash.
+     """
+
+     def __init__(self, dimension: int = 384):
+         """Initialize mock embedder."""
+         self._dimension = dimension
+
+     def encode(self, text: str) -> List[float]:
+         """Generate fake embedding based on text hash."""
+         import hashlib
+
+         # Create a deterministic embedding from the text's SHA-256 digest
+         hash_bytes = hashlib.sha256(text.encode()).digest()
+         # Cycle through the digest bytes to fill the requested dimension
+         embedding = []
+         for i in range(self._dimension):
+             byte_val = hash_bytes[i % len(hash_bytes)]
+             # Map the byte range [0, 255] onto [-1, 1]
+             embedding.append((byte_val / 127.5) - 1.0)
+         return embedding
+
+     def encode_batch(self, texts: List[str]) -> List[List[float]]:
+         """Generate fake embeddings for multiple texts."""
+         return [self.encode(text) for text in texts]
+
+     @property
+     def dimension(self) -> int:
+         """Return embedding dimension."""
+         return self._dimension
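+
+
+ # --- Editor's note: illustrative smoke test, not part of the original module ---
+ # MockEmbedder needs no third-party dependencies, so the provider contract
+ # can be exercised end to end:
+ if __name__ == "__main__":
+     mock = MockEmbedder(dimension=8)
+     vec = mock.encode("hello world")
+     assert len(vec) == mock.dimension
+     assert vec == mock.encode("hello world")  # deterministic by construction
+     assert all(-1.0 <= v <= 1.0 for v in vec)
+     print("mock embedding:", vec)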
@@ -0,0 +1,287 @@
+ """
+ ALMA Retrieval Engine.
+
+ Handles semantic search and memory retrieval with scoring and caching.
+ """
+
+ import logging
+ import time
+ from typing import Any, Dict, List, Optional
+
+ from alma.types import MemorySlice, MemoryScope
+ from alma.storage.base import StorageBackend
+ from alma.retrieval.scoring import MemoryScorer, ScoringWeights, ScoredItem
+ from alma.retrieval.cache import RetrievalCache, NullCache
+
+ logger = logging.getLogger(__name__)
+
+
+ class RetrievalEngine:
+     """
+     Retrieves relevant memories for task context injection.
+
+     Features:
+     - Semantic search via embeddings
+     - Recency weighting (newer memories preferred)
+     - Success-rate weighting (proven strategies ranked higher)
+     - Caching for repeated queries
+     - Configurable scoring weights
+     """
+
+     def __init__(
+         self,
+         storage: StorageBackend,
+         embedding_provider: str = "local",
+         cache_ttl_seconds: int = 300,
+         enable_cache: bool = True,
+         max_cache_entries: int = 1000,
+         scoring_weights: Optional[ScoringWeights] = None,
+         recency_half_life_days: float = 30.0,
+         min_score_threshold: float = 0.2,
+     ):
+         """
+         Initialize retrieval engine.
+
+         Args:
+             storage: Storage backend to query
+             embedding_provider: "local" (sentence-transformers), "azure"
+                 (Azure OpenAI), or "mock" (deterministic, for tests)
+             cache_ttl_seconds: How long to cache query results
+             enable_cache: Whether to enable caching
+             max_cache_entries: Maximum cache entries before eviction
+             scoring_weights: Custom weights for similarity/recency/success/confidence
+             recency_half_life_days: Days after which the recency score halves
+             min_score_threshold: Minimum score to include in results
+         """
+         self.storage = storage
+         self.embedding_provider = embedding_provider
+         self.min_score_threshold = min_score_threshold
+         self._embedder = None
+
+         # Initialize scorer
+         self.scorer = MemoryScorer(
+             weights=scoring_weights or ScoringWeights(),
+             recency_half_life_days=recency_half_life_days,
+         )
+
+         # Initialize cache
+         if enable_cache:
+             self.cache = RetrievalCache(
+                 ttl_seconds=cache_ttl_seconds,
+                 max_entries=max_cache_entries,
+             )
+         else:
+             self.cache = NullCache()
+
+     def retrieve(
+         self,
+         query: str,
+         agent: str,
+         project_id: str,
+         user_id: Optional[str] = None,
+         top_k: int = 5,
+         scope: Optional[MemoryScope] = None,
+         bypass_cache: bool = False,
+     ) -> MemorySlice:
+         """
+         Retrieve relevant memories for a task.
+
+         Args:
+             query: Task description to find relevant memories for
+             agent: Agent requesting memories
+             project_id: Project context
+             user_id: Optional user for preference retrieval
+             top_k: Max items per memory type
+             scope: Agent's learning scope for filtering
+             bypass_cache: Skip cache lookup/storage
+
+         Returns:
+             MemorySlice with relevant memories, scored and ranked
+         """
+         start_time = time.time()
+
+         # Check cache first
+         if not bypass_cache:
+             cached = self.cache.get(query, agent, project_id, user_id, top_k)
+             if cached is not None:
+                 cached.retrieval_time_ms = int((time.time() - start_time) * 1000)
+                 logger.debug(f"Cache hit for query: {query[:50]}...")
+                 return cached
+
+         # Generate embedding for the query
+         query_embedding = self._get_embedding(query)
+
+         # Retrieve raw items from storage (with vector search)
+         raw_heuristics = self.storage.get_heuristics(
+             project_id=project_id,
+             agent=agent,
+             embedding=query_embedding,
+             top_k=top_k * 2,  # Get extra for scoring/filtering
+             min_confidence=0.0,  # Let the scorer handle filtering
+         )
+
+         raw_outcomes = self.storage.get_outcomes(
+             project_id=project_id,
+             agent=agent,
+             embedding=query_embedding,
+             top_k=top_k * 2,
+             success_only=False,
+         )
+
+         raw_domain_knowledge = self.storage.get_domain_knowledge(
+             project_id=project_id,
+             agent=agent,
+             embedding=query_embedding,
+             top_k=top_k * 2,
+         )
+
+         raw_anti_patterns = self.storage.get_anti_patterns(
+             project_id=project_id,
+             agent=agent,
+             embedding=query_embedding,
+             top_k=top_k * 2,
+         )
+
+         # Score and rank each memory type
+         scored_heuristics = self.scorer.score_heuristics(raw_heuristics)
+         scored_outcomes = self.scorer.score_outcomes(raw_outcomes)
+         scored_knowledge = self.scorer.score_domain_knowledge(raw_domain_knowledge)
+         scored_anti_patterns = self.scorer.score_anti_patterns(raw_anti_patterns)
+
+         # Apply threshold and limit
+         final_heuristics = self._extract_top_k(scored_heuristics, top_k)
+         final_outcomes = self._extract_top_k(scored_outcomes, top_k)
+         final_knowledge = self._extract_top_k(scored_knowledge, top_k)
+         final_anti_patterns = self._extract_top_k(scored_anti_patterns, top_k)
+
+         # Get user preferences (retrieved as-is, not scored)
+         preferences = []
+         if user_id:
+             preferences = self.storage.get_user_preferences(user_id=user_id)
+
+         retrieval_time_ms = int((time.time() - start_time) * 1000)
+
+         result = MemorySlice(
+             heuristics=final_heuristics,
+             outcomes=final_outcomes,
+             preferences=preferences,
+             domain_knowledge=final_knowledge,
+             anti_patterns=final_anti_patterns,
+             query=query,
+             agent=agent,
+             retrieval_time_ms=retrieval_time_ms,
+         )
+
+         # Cache the result
+         if not bypass_cache:
+             self.cache.set(query, agent, project_id, result, user_id, top_k)
+
+         logger.info(
+             f"Retrieved {result.total_items} memories for '{query[:50]}...' "
+             f"in {retrieval_time_ms}ms"
+         )
+
+         return result
+
+     def _extract_top_k(
+         self,
+         scored_items: List[ScoredItem],
+         top_k: int,
+     ) -> List[Any]:
+         """
+         Extract top-k items after filtering by score threshold.
+
+         Args:
+             scored_items: Scored and sorted items
+             top_k: Maximum number to return
+
+         Returns:
+             List of original items (unwrapped from ScoredItem)
+         """
+         filtered = self.scorer.apply_score_threshold(
+             scored_items, self.min_score_threshold
+         )
+         return [item.item for item in filtered[:top_k]]
+
+     def _get_embedding(self, text: str) -> List[float]:
+         """
+         Generate embedding for text.
+
+         Uses lazy initialization of the embedding model.
+         """
+         if self._embedder is None:
+             self._embedder = self._init_embedder()
+
+         return self._embedder.encode(text)
+
+     def _init_embedder(self):
+         """Initialize the embedding model based on provider config."""
+         if self.embedding_provider == "azure":
+             from alma.retrieval.embeddings import AzureEmbedder
+
+             return AzureEmbedder()
+         elif self.embedding_provider == "mock":
+             from alma.retrieval.embeddings import MockEmbedder
+
+             return MockEmbedder()
+         else:
+             from alma.retrieval.embeddings import LocalEmbedder
+
+             return LocalEmbedder()
+
+     def invalidate_cache(
+         self,
+         agent: Optional[str] = None,
+         project_id: Optional[str] = None,
+     ):
+         """
+         Invalidate cache entries.
+
+         Should be called after memory updates to ensure fresh results.
+
+         Args:
+             agent: Invalidate entries for this agent
+             project_id: Invalidate entries for this project
+         """
+         self.cache.invalidate(agent=agent, project_id=project_id)
+
+     def get_cache_stats(self) -> Dict[str, Any]:
+         """Get cache performance statistics."""
+         stats = self.cache.get_stats()
+         return stats.to_dict()
+
+     def clear_cache(self):
+         """Clear all cached results."""
+         self.cache.clear()
+
+     def get_scorer_weights(self) -> Dict[str, float]:
+         """Get current scoring weights."""
+         w = self.scorer.weights
+         return {
+             "similarity": w.similarity,
+             "recency": w.recency,
+             "success_rate": w.success_rate,
+             "confidence": w.confidence,
+         }
+
+     def update_scorer_weights(
+         self,
+         similarity: Optional[float] = None,
+         recency: Optional[float] = None,
+         success_rate: Optional[float] = None,
+         confidence: Optional[float] = None,
+     ):
+         """
+         Update scoring weights (normalized to sum to 1.0).
+
+         Args:
+             similarity: Weight for semantic similarity
+             recency: Weight for recency
+             success_rate: Weight for success rate
+             confidence: Weight for stored confidence
+         """
+         current = self.scorer.weights
+         self.scorer.weights = ScoringWeights(
+             similarity=similarity if similarity is not None else current.similarity,
+             recency=recency if recency is not None else current.recency,
+             success_rate=success_rate if success_rate is not None else current.success_rate,
+             confidence=confidence if confidence is not None else current.confidence,
+         )
+         # Clear the cache since scoring changed
+         self.cache.clear()
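+
+
+ # --- Editor's note: illustrative usage, not part of the original module ---
+ # A minimal wiring sketch; "SQLiteStorage" is a hypothetical StorageBackend
+ # implementation, and the "mock" provider keeps the example dependency-free:
+ #
+ #     storage = SQLiteStorage("alma.db")
+ #     engine = RetrievalEngine(storage, embedding_provider="mock")
+ #     memories = engine.retrieve(
+ #         query="migrate the auth service to OAuth2",
+ #         agent="planner",
+ #         project_id="acme",
+ #         top_k=3,
+ #     )
+ #     engine.update_scorer_weights(recency=0.5)  # re-weight; also clears cache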