flock-core 0.5.20__py3-none-any.whl → 0.5.22__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of flock-core might be problematic.

@@ -0,0 +1,173 @@
+"""Semantic context providers for agent execution.
+
+This module provides context providers that use semantic similarity to find
+relevant historical artifacts for agent context.
+"""
+
+from __future__ import annotations
+
+from collections.abc import Callable
+from typing import TYPE_CHECKING, Any
+
+from pydantic import BaseModel
+
+
+if TYPE_CHECKING:
+    from flock.core.artifacts import Artifact
+    from flock.core.store import ArtifactStore
+
+
+class SemanticContextProvider:
+    """Context provider that retrieves semantically relevant historical artifacts.
+
+    This provider uses semantic similarity to find artifacts that are relevant
+    to a given query text, enabling agents to make decisions based on similar
+    past events.
+
+    Args:
+        query_text: The semantic query to match against artifacts
+        threshold: Minimum similarity score (0.0 to 1.0) to include in results
+        limit: Maximum number of artifacts to return
+        extract_field: Optional field name to extract from artifact payload for matching.
+            If None, uses all text from payload.
+        artifact_type: Optional type filter - only return artifacts of this type
+        where: Optional predicate filter for additional filtering
+
+    Example:
+        ```python
+        provider = SemanticContextProvider(
+            query_text="user authentication issues", threshold=0.5, limit=5
+        )
+
+        relevant_artifacts = await provider.get_context(store)
+        ```
+    """
+
+    def __init__(
+        self,
+        query_text: str,
+        threshold: float = 0.4,
+        limit: int = 10,
+        extract_field: str | None = None,
+        artifact_type: type[BaseModel] | None = None,
+        where: Callable[[Artifact], bool] | None = None,
+    ):
+        """Initialize semantic context provider.
+
+        Args:
+            query_text: The semantic query text
+            threshold: Minimum similarity score (default: 0.4)
+            limit: Maximum results to return (default: 10)
+            extract_field: Optional field to extract from payload
+            artifact_type: Optional type filter
+            where: Optional predicate for additional filtering
+        """
+        if not query_text or not query_text.strip():
+            raise ValueError("query_text cannot be empty")
+
+        if not 0.0 <= threshold <= 1.0:
+            raise ValueError("threshold must be between 0 and 1")
+
+        if limit < 1:
+            raise ValueError("limit must be at least 1")
+
+        self.query_text = query_text
+        self.threshold = threshold
+        self.limit = limit
+        self.extract_field = extract_field
+        self.artifact_type = artifact_type
+        self.where = where
+
+    async def get_context(self, store: ArtifactStore) -> list[Artifact]:
+        """Retrieve semantically relevant artifacts from store.
+
+        Args:
+            store: The artifact store to query
+
+        Returns:
+            List of relevant artifacts, sorted by similarity (highest first)
+        """
+        # Check whether semantic features are available
+        try:
+            from flock.semantic import SEMANTIC_AVAILABLE, EmbeddingService
+        except ImportError:
+            return []
+
+        if not SEMANTIC_AVAILABLE:
+            return []
+
+        try:
+            embedding_service = EmbeddingService.get_instance()
+        except Exception:
+            return []
+
+        # Pre-compute the query embedding (validates the query and warms the cache)
+        try:
+            query_embedding = embedding_service.embed(self.query_text)
+        except Exception:
+            return []
+
+        # Get all artifacts from store
+        all_artifacts = await store.list()
+
+        # Filter by type if specified
+        if self.artifact_type:
+            from flock.registry import type_registry
+
+            type_name = type_registry.register(self.artifact_type)
+            all_artifacts = [a for a in all_artifacts if a.type == type_name]
+
+        # Filter by where clause if specified
+        if self.where:
+            all_artifacts = [a for a in all_artifacts if self.where(a)]
+
+        # Compute similarities and filter
+        results: list[tuple[Artifact, float]] = []
+
+        for artifact in all_artifacts:
+            try:
+                # Extract text from artifact
+                if self.extract_field:
+                    # Use specific field
+                    text = str(artifact.payload.get(self.extract_field, ""))
+                else:
+                    # Use all text from payload
+                    text = self._extract_text_from_payload(artifact.payload)
+
+                if not text or not text.strip():
+                    continue
+
+                # Compute similarity
+                similarity = embedding_service.similarity(self.query_text, text)
+
+                # Check threshold
+                if similarity >= self.threshold:
+                    results.append((artifact, similarity))
+
+            except Exception:
+                # Skip artifacts that fail processing
+                continue
+
+        # Sort by similarity (highest first) and take top N
+        results.sort(key=lambda x: x[1], reverse=True)
+        return [artifact for artifact, _ in results[: self.limit]]
+
+    def _extract_text_from_payload(self, payload: dict[str, Any]) -> str:
+        """Extract all text content from payload.
+
+        Args:
+            payload: The artifact payload dict
+
+        Returns:
+            str: Concatenated text from all string fields
+        """
+        text_parts = []
+        for value in payload.values():
+            if isinstance(value, str):
+                text_parts.append(value)
+            elif isinstance(value, (list, tuple)):
+                for item in value:
+                    if isinstance(item, str):
+                        text_parts.append(item)
+
+        return " ".join(text_parts)
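For orientation, here is a hedged usage sketch of the provider above. The import path, the `BugReport` payload model, and the `store` object are illustrative assumptions, not part of this diff:

```python
from pydantic import BaseModel

# Import path assumed; adjust to where flock-core exposes the provider.
from flock.core.context import SemanticContextProvider


class BugReport(BaseModel):  # hypothetical artifact payload type
    title: str
    body: str


async def gather_context(store):  # `store` is an ArtifactStore instance
    provider = SemanticContextProvider(
        query_text="user authentication issues",
        threshold=0.5,                   # keep only matches scoring >= 0.5
        limit=5,                         # return at most 5 artifacts
        extract_field="title",           # match against this payload field only
        artifact_type=BugReport,         # type filter applied before scoring
        where=lambda a: bool(a.payload.get("body")),  # extra predicate filter
    )
    # Sorted by similarity, highest first; [] if semantic extras are absent.
    return await provider.get_context(store)
```

Note that `get_context` degrades gracefully: if the optional semantic dependencies are missing or the embedding model cannot be loaded, it returns an empty list rather than raising.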
@@ -0,0 +1,235 @@
+"""Embedding service for semantic matching.
+
+This module provides a singleton service for generating and caching embeddings
+using sentence-transformers.
+"""
+
+import logging
+from collections import OrderedDict
+
+import numpy as np
+
+
+logger = logging.getLogger(__name__)
+
+
+class LRUCache:
+    """Simple LRU cache with size limit."""
+
+    def __init__(self, max_size: int = 10000):
+        """Initialize LRU cache.
+
+        Args:
+            max_size: Maximum number of entries
+        """
+        self.max_size = max_size
+        self._cache: OrderedDict[str, np.ndarray] = OrderedDict()
+
+    def get(self, key: str) -> np.ndarray | None:
+        """Get value and mark as recently used."""
+        if key not in self._cache:
+            return None
+        # Move to end (most recent)
+        self._cache.move_to_end(key)
+        return self._cache[key]
+
+    def put(self, key: str, value: np.ndarray) -> None:
+        """Put value and evict the least recently used entry if needed."""
+        if key in self._cache:
+            # Update and move to end
+            self._cache.move_to_end(key)
+        self._cache[key] = value
+
+        # Evict oldest if over limit
+        if len(self._cache) > self.max_size:
+            self._cache.popitem(last=False)  # Remove oldest (first item)
+
+    def __contains__(self, key: str) -> bool:
+        """Check if key exists in cache."""
+        return key in self._cache
+
+    def __len__(self) -> int:
+        """Get cache size."""
+        return len(self._cache)
+
+
+class EmbeddingService:
+    """Singleton service for text embeddings using sentence-transformers.
+
+    This class manages the lifecycle of the embedding model and provides
+    efficient caching of embeddings.
+    """
+
+    _instance = None
+
+    def __init__(self, cache_size: int = 10000):
+        """Private constructor - use get_instance() instead.
+
+        Args:
+            cache_size: Maximum number of embeddings to cache
+        """
+        self._model = None
+        self._cache = LRUCache(max_size=cache_size)
+        self._cache_size = cache_size
+        self._hits = 0
+        self._misses = 0
+
+    @staticmethod
+    def get_instance(cache_size: int = 10000):
+        """Get or create the singleton EmbeddingService instance.
+
+        Args:
+            cache_size: Maximum number of embeddings to cache (default: 10000)
+
+        Returns:
+            EmbeddingService: The singleton instance
+        """
+        if EmbeddingService._instance is None:
+            EmbeddingService._instance = EmbeddingService(cache_size=cache_size)
+        return EmbeddingService._instance
+
+    def _load_model(self):
+        """Lazily load the sentence-transformers model."""
+        if self._model is None:
+            from sentence_transformers import SentenceTransformer
+
+            logger.info("Loading sentence-transformers model: all-MiniLM-L6-v2")
+            self._model = SentenceTransformer("all-MiniLM-L6-v2")
+            logger.info("Model loaded successfully")
+
+    def embed(self, text: str) -> np.ndarray:
+        """Generate embedding for a single text.
+
+        Args:
+            text: The text to embed
+
+        Returns:
+            np.ndarray: 384-dimensional embedding vector
+
+        Raises:
+            ValueError: If text is empty
+        """
+        if not text or not text.strip():
+            raise ValueError("Cannot embed empty text")
+
+        # Check cache first
+        cached = self._cache.get(text)
+        if cached is not None:
+            self._hits += 1
+            return cached
+
+        # Cache miss - generate embedding
+        self._misses += 1
+        self._load_model()
+
+        # Generate embedding
+        embedding = self._model.encode(
+            text, convert_to_numpy=True, show_progress_bar=False
+        )
+
+        # Ensure it's a float32 numpy array and flatten to 1D
+        if not isinstance(embedding, np.ndarray):
+            embedding = np.array(embedding, dtype=np.float32)
+
+        # Flatten to 1D if needed (model might return (1, 384) for single text)
+        if embedding.ndim > 1:
+            embedding = embedding.flatten()
+
+        # Store in cache
+        self._cache.put(text, embedding)
+
+        return embedding
+
+    def embed_batch(self, texts: list[str]) -> list[np.ndarray]:
+        """Generate embeddings for multiple texts efficiently.
+
+        Args:
+            texts: List of texts to embed
+
+        Returns:
+            list[np.ndarray]: List of embedding vectors
+        """
+        if not texts:
+            return []
+
+        # Separate cached and uncached
+        results = [None] * len(texts)
+        to_encode = []
+        to_encode_indices = []
+
+        for i, text in enumerate(texts):
+            cached = self._cache.get(text)
+            if cached is not None:
+                results[i] = cached
+                self._hits += 1
+            else:
+                to_encode.append(text)
+                to_encode_indices.append(i)
+                self._misses += 1
+
+        # Batch encode uncached texts
+        if to_encode:
+            self._load_model()
+            embeddings = self._model.encode(
+                to_encode, convert_to_numpy=True, show_progress_bar=False
+            )
+
+            # Store in cache and results
+            for i, (text, embedding) in enumerate(
+                zip(to_encode, embeddings, strict=False)
+            ):
+                if not isinstance(embedding, np.ndarray):
+                    embedding = np.array(embedding, dtype=np.float32)
+                # Flatten to 1D if needed
+                if embedding.ndim > 1:
+                    embedding = embedding.flatten()
+                self._cache.put(text, embedding)
+                results[to_encode_indices[i]] = embedding
+
+        return results  # type: ignore
+
+    def similarity(self, text1: str, text2: str) -> float:
+        """Compute semantic similarity between two texts.
+
+        Uses cosine similarity between embeddings.
+
+        Args:
+            text1: First text
+            text2: Second text
+
+        Returns:
+            float: Similarity score between 0 and 1
+        """
+        emb1 = self.embed(text1)
+        emb2 = self.embed(text2)
+
+        # Compute cosine similarity
+        dot_product = np.dot(emb1, emb2)
+        norm1 = np.linalg.norm(emb1)
+        norm2 = np.linalg.norm(emb2)
+
+        if norm1 == 0 or norm2 == 0:
+            return 0.0
+
+        similarity = dot_product / (norm1 * norm2)
+
+        # Clamp to [0, 1] to absorb floating-point error (and negative cosines)
+        return float(max(0.0, min(1.0, similarity)))
+
+    def get_cache_stats(self) -> dict:
+        """Get cache hit/miss statistics.
+
+        Returns:
+            dict: Statistics including hits, misses, and hit rate
+        """
+        total = self._hits + self._misses
+        hit_rate = self._hits / total if total > 0 else 0.0
+
+        return {
+            "hits": self._hits,
+            "misses": self._misses,
+            "total": total,
+            "hit_rate": hit_rate,
+            "cache_size": len(self._cache),
+            "cache_limit": self._cache_size,
+        }
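Finally, a hedged sketch of driving the embedding service directly. It assumes the optional sentence-transformers dependency is installed (the `SEMANTIC_AVAILABLE` flag guards this in the provider module); the example texts are illustrative:

```python
from flock.semantic import EmbeddingService

service = EmbeddingService.get_instance()  # cache_size is only honored on first call

# The first embed() lazily loads all-MiniLM-L6-v2, then caches the vector.
vec = service.embed("user authentication issues")
assert vec.shape == (384,)  # MiniLM-L6-v2 produces 384-dimensional embeddings

# Batch calls reuse cached vectors; only uncached texts hit the model.
service.embed_batch(["login failed", "password reset"])

# Cosine similarity over the (now cached) embeddings, clamped to [0, 1].
score = service.similarity("login failed", "sign-in error")

print(service.get_cache_stats())  # hits, misses, total, hit_rate, cache_size, cache_limit
```

One design consequence worth noting: because `get_instance` constructs the singleton only once, a later call that passes a different `cache_size` silently keeps the original limit.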