flock-core 0.5.21__py3-none-any.whl → 0.5.23__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flock-core might be problematic; see the registry's advisory notice for this release for details.
- flock/api/models.py +3 -2
- flock/api/service.py +0 -1
- flock/core/agent.py +51 -16
- flock/core/orchestrator.py +18 -6
- flock/core/subscription.py +151 -8
- flock/semantic/__init__.py +49 -0
- flock/semantic/context_provider.py +173 -0
- flock/semantic/embedding_service.py +239 -0
- flock_core-0.5.23.dist-info/METADATA +976 -0
- {flock_core-0.5.21.dist-info → flock_core-0.5.23.dist-info}/RECORD +13 -10
- flock_core-0.5.21.dist-info/METADATA +0 -1327
- {flock_core-0.5.21.dist-info → flock_core-0.5.23.dist-info}/WHEEL +0 -0
- {flock_core-0.5.21.dist-info → flock_core-0.5.23.dist-info}/entry_points.txt +0 -0
- {flock_core-0.5.21.dist-info → flock_core-0.5.23.dist-info}/licenses/LICENSE +0 -0
flock/api/models.py
CHANGED
|
@@ -20,8 +20,9 @@ class AgentSubscription(BaseModel):
|
|
|
20
20
|
"""Subscription configuration for an agent."""
|
|
21
21
|
|
|
22
22
|
types: list[str] = Field(description="Artifact types this subscription consumes")
|
|
23
|
-
mode: str = Field(
|
|
24
|
-
|
|
23
|
+
mode: str = Field(
|
|
24
|
+
description="Subscription mode (e.g., 'both', 'direct', 'events')"
|
|
25
|
+
)
|
|
25
26
|
|
|
26
27
|
|
|
27
28
|
class Agent(BaseModel):
|
flock/api/service.py
CHANGED
flock/core/agent.py
CHANGED
|
@@ -19,7 +19,7 @@ from flock.agent.mcp_integration import MCPIntegration
|
|
|
19
19
|
# Phase 4: Import extracted modules
|
|
20
20
|
from flock.agent.output_processor import OutputProcessor
|
|
21
21
|
from flock.core.artifacts import Artifact, ArtifactSpec
|
|
22
|
-
from flock.core.subscription import BatchSpec, JoinSpec, Subscription
|
|
22
|
+
from flock.core.subscription import BatchSpec, JoinSpec, Subscription
|
|
23
23
|
from flock.core.visibility import AgentIdentity, Visibility, ensure_visibility
|
|
24
24
|
from flock.logging.auto_trace import AutoTracedMeta
|
|
25
25
|
from flock.logging.logging import get_logger
|
|
@@ -226,11 +226,7 @@ class Agent(metaclass=AutoTracedMeta):
|
|
|
226
226
|
comp_name = self._component_display_name(component)
|
|
227
227
|
priority = getattr(component, "priority", 0)
|
|
228
228
|
logger.info(
|
|
229
|
-
"Agent
|
|
230
|
-
self.name,
|
|
231
|
-
comp_name,
|
|
232
|
-
priority,
|
|
233
|
-
len(self.utilities),
|
|
229
|
+
f"Agent {self.name}: utility added: component={comp_name}, priority={priority}, total_utilities={len(self.utilities)}"
|
|
234
230
|
)
|
|
235
231
|
self.utilities.sort(key=lambda comp: getattr(comp, "priority", 0))
|
|
236
232
|
|
|
@@ -539,13 +535,16 @@ class AgentBuilder:
|
|
|
539
535
|
where: Callable[[BaseModel], bool]
|
|
540
536
|
| Sequence[Callable[[BaseModel], bool]]
|
|
541
537
|
| None = None,
|
|
542
|
-
|
|
543
|
-
|
|
538
|
+
semantic_match: str
|
|
539
|
+
| list[str]
|
|
540
|
+
| list[dict[str, Any]]
|
|
541
|
+
| dict[str, Any]
|
|
542
|
+
| None = None,
|
|
543
|
+
semantic_threshold: float = 0.0,
|
|
544
544
|
from_agents: Iterable[str] | None = None,
|
|
545
545
|
tags: Iterable[str] | None = None,
|
|
546
546
|
join: dict | JoinSpec | None = None,
|
|
547
547
|
batch: dict | BatchSpec | None = None,
|
|
548
|
-
delivery: str = "exclusive",
|
|
549
548
|
mode: str = "both",
|
|
550
549
|
priority: int = 0,
|
|
551
550
|
) -> AgentBuilder:
|
|
@@ -558,14 +557,21 @@ class AgentBuilder:
|
|
|
558
557
|
*types: Artifact types (Pydantic models) to consume
|
|
559
558
|
where: Optional filter predicate(s). Agent only executes if predicate returns True.
|
|
560
559
|
Can be a single callable or sequence of callables (all must pass).
|
|
561
|
-
|
|
562
|
-
|
|
560
|
+
semantic_match: Optional semantic similarity filter. Matches artifacts based on
|
|
561
|
+
meaning rather than keywords. Can be:
|
|
562
|
+
- str: Single query (e.g., "security vulnerability")
|
|
563
|
+
- list[str]: Multiple queries, all must match (AND logic)
|
|
564
|
+
- dict: Advanced config with "query", "threshold", "field"
|
|
565
|
+
- list[dict]: Multiple queries with individual thresholds
|
|
566
|
+
semantic_threshold: Minimum similarity threshold for semantic matching (0.0-1.0).
|
|
567
|
+
Applied to all queries when semantic_match is a string or list of strings.
|
|
568
|
+
Ignored if semantic_match is a dict/list of dicts with explicit "threshold".
|
|
569
|
+
Default: 0.0 (uses default 0.4 when not specified)
|
|
563
570
|
from_agents: Only consume artifacts from specific agents
|
|
564
571
|
tags: Only consume artifacts with matching tags
|
|
565
572
|
join: Join specification for coordinating multiple artifact types
|
|
566
573
|
batch: Batch specification for processing multiple artifacts together
|
|
567
|
-
|
|
568
|
-
mode: Processing mode - "both", "streaming", or "batch"
|
|
574
|
+
mode: Processing mode - "both", "direct", or "events"
|
|
569
575
|
priority: Execution priority (higher = executes first)
|
|
570
576
|
|
|
571
577
|
Returns:
|
|
@@ -587,6 +593,12 @@ class AgentBuilder:
|
|
|
587
593
|
... where=[lambda o: o.total > 100, lambda o: o.status == "pending"],
|
|
588
594
|
... )
|
|
589
595
|
|
|
596
|
+
>>> # Semantic matching
|
|
597
|
+
>>> agent.consumes(Ticket, semantic_match="security vulnerability")
|
|
598
|
+
|
|
599
|
+
>>> # Semantic matching with custom threshold
|
|
600
|
+
>>> agent.consumes(Ticket, semantic_match="urgent", semantic_threshold=0.6)
|
|
601
|
+
|
|
590
602
|
>>> # Consume from specific agents
|
|
591
603
|
>>> agent.consumes(Report, from_agents=["analyzer", "validator"])
|
|
592
604
|
|
|
@@ -607,17 +619,40 @@ class AgentBuilder:
|
|
|
607
619
|
# Phase 5B: Use BuilderValidator for normalization
|
|
608
620
|
join_spec = BuilderValidator.normalize_join(join)
|
|
609
621
|
batch_spec = BuilderValidator.normalize_batch(batch)
|
|
610
|
-
|
|
622
|
+
|
|
623
|
+
# Handle semantic_threshold parameter to control semantic matching threshold
|
|
624
|
+
# If semantic_threshold is provided and semantic_match is simple, convert to dict
|
|
625
|
+
semantic_param: (
|
|
626
|
+
str | list[str] | list[dict[str, Any]] | dict[str, Any] | None
|
|
627
|
+
) = semantic_match
|
|
628
|
+
if semantic_match is not None and semantic_threshold > 0.0:
|
|
629
|
+
if isinstance(semantic_match, str):
|
|
630
|
+
# Simple string: create dict with semantic_threshold as threshold
|
|
631
|
+
semantic_param = {
|
|
632
|
+
"query": semantic_match,
|
|
633
|
+
"threshold": semantic_threshold,
|
|
634
|
+
}
|
|
635
|
+
elif isinstance(semantic_match, list):
|
|
636
|
+
# List of strings: convert to list of dicts with semantic_threshold
|
|
637
|
+
semantic_param = [
|
|
638
|
+
{"query": q, "threshold": semantic_threshold}
|
|
639
|
+
for q in semantic_match
|
|
640
|
+
]
|
|
641
|
+
elif isinstance(semantic_match, dict) and "threshold" not in semantic_match:
|
|
642
|
+
# Dict without explicit threshold: add semantic_threshold
|
|
643
|
+
semantic_param = {**semantic_match, "threshold": semantic_threshold}
|
|
644
|
+
|
|
645
|
+
# Semantic matching: pass semantic_match parameter to Subscription
|
|
646
|
+
# which will parse it into TextPredicate objects
|
|
611
647
|
subscription = Subscription(
|
|
612
648
|
agent_name=self._agent.name,
|
|
613
649
|
types=types,
|
|
614
650
|
where=predicates,
|
|
615
|
-
|
|
651
|
+
semantic_match=semantic_param, # Let Subscription handle conversion
|
|
616
652
|
from_agents=from_agents,
|
|
617
653
|
tags=tags,
|
|
618
654
|
join=join_spec,
|
|
619
655
|
batch=batch_spec,
|
|
620
|
-
delivery=delivery,
|
|
621
656
|
mode=mode,
|
|
622
657
|
priority=priority,
|
|
623
658
|
)
|
flock/core/orchestrator.py
CHANGED
|
@@ -111,7 +111,7 @@ class Flock(metaclass=AutoTracedMeta):
|
|
|
111
111
|
# Patch litellm imports and setup logger
|
|
112
112
|
self._patch_litellm_proxy_imports()
|
|
113
113
|
self._logger = logging.getLogger(__name__)
|
|
114
|
-
self.model = model
|
|
114
|
+
self.model = model or os.getenv("DEFAULT_MODEL")
|
|
115
115
|
|
|
116
116
|
# Phase 3: Initialize all components using OrchestratorInitializer
|
|
117
117
|
components = OrchestratorInitializer.initialize_components(
|
|
@@ -168,10 +168,6 @@ class Flock(metaclass=AutoTracedMeta):
|
|
|
168
168
|
self._scheduler = AgentScheduler(self, self._component_runner)
|
|
169
169
|
self._artifact_manager = ArtifactManager(self, self.store, self._scheduler)
|
|
170
170
|
|
|
171
|
-
# Resolve model default
|
|
172
|
-
if not model:
|
|
173
|
-
self.model = os.getenv("DEFAULT_MODEL")
|
|
174
|
-
|
|
175
171
|
# Log initialization
|
|
176
172
|
self._logger.debug("Orchestrator initialized: components=[]")
|
|
177
173
|
|
|
@@ -496,13 +492,17 @@ class Flock(metaclass=AutoTracedMeta):
|
|
|
496
492
|
|
|
497
493
|
# Runtime --------------------------------------------------------------
|
|
498
494
|
|
|
499
|
-
async def run_until_idle(self) -> None:
|
|
495
|
+
async def run_until_idle(self, *, wait_for_input: bool = False) -> None:
|
|
500
496
|
"""Wait for all scheduled agent tasks to complete.
|
|
501
497
|
|
|
502
498
|
This method blocks until the blackboard reaches a stable state where no
|
|
503
499
|
agents are queued for execution. Essential for batch processing and ensuring
|
|
504
500
|
all agent cascades complete before continuing.
|
|
505
501
|
|
|
502
|
+
Args:
|
|
503
|
+
wait_for_input: If True, waits for user input before returning (default: False).
|
|
504
|
+
Useful for debugging or step-by-step execution.
|
|
505
|
+
|
|
506
506
|
Note:
|
|
507
507
|
Automatically resets circuit breaker counters and shuts down MCP connections
|
|
508
508
|
when idle. Used with publish() for event-driven workflows.
|
|
@@ -518,6 +518,12 @@ class Flock(metaclass=AutoTracedMeta):
|
|
|
518
518
|
>>> await flock.publish_many([task1, task2, task3])
|
|
519
519
|
>>> await flock.run_until_idle() # All tasks processed in parallel
|
|
520
520
|
|
|
521
|
+
>>> # Step-by-step execution with user prompts
|
|
522
|
+
>>> await flock.publish(task1)
|
|
523
|
+
>>> await flock.run_until_idle(wait_for_input=True) # Pauses for user input
|
|
524
|
+
>>> await flock.publish(task2)
|
|
525
|
+
>>> await flock.run_until_idle(wait_for_input=True) # Pauses again
|
|
526
|
+
|
|
521
527
|
See Also:
|
|
522
528
|
- publish(): Event-driven artifact publishing
|
|
523
529
|
- publish_many(): Batch publishing for parallel execution
|
|
@@ -557,6 +563,12 @@ class Flock(metaclass=AutoTracedMeta):
|
|
|
557
563
|
# Automatically shutdown MCP connections when idle
|
|
558
564
|
await self.shutdown(include_components=False)
|
|
559
565
|
|
|
566
|
+
# Wait for user input if requested
|
|
567
|
+
if wait_for_input:
|
|
568
|
+
# Use asyncio.to_thread to avoid blocking the event loop
|
|
569
|
+
# since input() is a blocking I/O operation
|
|
570
|
+
await asyncio.to_thread(input, "Press any key to continue....")
|
|
571
|
+
|
|
560
572
|
async def direct_invoke(
|
|
561
573
|
self, agent: Agent, inputs: Sequence[BaseModel | Mapping[str, Any] | Artifact]
|
|
562
574
|
) -> list[Artifact]:
|
flock/core/subscription.py
CHANGED
|
@@ -21,8 +21,17 @@ Predicate = Callable[[BaseModel], bool]
|
|
|
21
21
|
|
|
22
22
|
@dataclass
|
|
23
23
|
class TextPredicate:
|
|
24
|
-
text
|
|
25
|
-
|
|
24
|
+
"""Semantic text matching predicate.
|
|
25
|
+
|
|
26
|
+
Args:
|
|
27
|
+
query: The semantic query text to match against
|
|
28
|
+
threshold: Minimum similarity score (0.0 to 1.0) to consider a match
|
|
29
|
+
field: Optional field name to extract from payload. If None, uses all text.
|
|
30
|
+
"""
|
|
31
|
+
|
|
32
|
+
query: str
|
|
33
|
+
threshold: float = 0.4 # Default threshold for semantic matching
|
|
34
|
+
field: str | None = None # Optional field to extract from payload
|
|
26
35
|
|
|
27
36
|
|
|
28
37
|
@dataclass
|
|
@@ -97,21 +106,21 @@ class Subscription:
|
|
|
97
106
|
def __init__(
|
|
98
107
|
self,
|
|
99
108
|
*,
|
|
100
|
-
agent_name: str,
|
|
109
|
+
agent_name: str | None = None,
|
|
101
110
|
types: Sequence[type[BaseModel]],
|
|
102
111
|
where: Sequence[Predicate] | None = None,
|
|
103
112
|
text_predicates: Sequence[TextPredicate] | None = None,
|
|
113
|
+
semantic_match: str | list[str | dict[str, Any]] | dict[str, Any] | None = None,
|
|
104
114
|
from_agents: Iterable[str] | None = None,
|
|
105
115
|
tags: Iterable[str] | None = None,
|
|
106
116
|
join: JoinSpec | None = None,
|
|
107
117
|
batch: BatchSpec | None = None,
|
|
108
|
-
delivery: str = "exclusive",
|
|
109
118
|
mode: str = "both",
|
|
110
119
|
priority: int = 0,
|
|
111
120
|
) -> None:
|
|
112
121
|
if not types:
|
|
113
122
|
raise ValueError("Subscription must declare at least one type.")
|
|
114
|
-
self.agent_name = agent_name
|
|
123
|
+
self.agent_name = agent_name or ""
|
|
115
124
|
self.type_models: list[type[BaseModel]] = list(types)
|
|
116
125
|
|
|
117
126
|
# Register all types and build counts (supports duplicates for count-based AND gates)
|
|
@@ -127,15 +136,62 @@ class Subscription:
|
|
|
127
136
|
self.type_counts[type_name] = self.type_counts.get(type_name, 0) + 1
|
|
128
137
|
|
|
129
138
|
self.where = list(where or [])
|
|
130
|
-
|
|
139
|
+
|
|
140
|
+
# Parse semantic_match parameter into TextPredicate objects
|
|
141
|
+
parsed_text_predicates = self._parse_semantic_match_parameter(semantic_match)
|
|
142
|
+
self.text_predicates = list(text_predicates or []) + parsed_text_predicates
|
|
143
|
+
|
|
131
144
|
self.from_agents = set(from_agents or [])
|
|
132
145
|
self.tags = set(tags or [])
|
|
133
146
|
self.join = join
|
|
134
147
|
self.batch = batch
|
|
135
|
-
self.delivery = delivery
|
|
136
148
|
self.mode = mode
|
|
137
149
|
self.priority = priority
|
|
138
150
|
|
|
151
|
+
def _parse_semantic_match_parameter(
|
|
152
|
+
self, semantic_match: str | list[str | dict[str, Any]] | dict[str, Any] | None
|
|
153
|
+
) -> list[TextPredicate]:
|
|
154
|
+
"""Parse the semantic_match parameter into TextPredicate objects.
|
|
155
|
+
|
|
156
|
+
Args:
|
|
157
|
+
semantic_match: Can be:
|
|
158
|
+
- str: "query" → TextPredicate(query="query", threshold=0.4)
|
|
159
|
+
- list: ["q1", "q2"] → multiple TextPredicates (AND logic)
|
|
160
|
+
or [{"query": "q1", "threshold": 0.8}, ...] with explicit thresholds
|
|
161
|
+
- dict: {"query": "...", "threshold": 0.8, "field": "body"}
|
|
162
|
+
|
|
163
|
+
Returns:
|
|
164
|
+
List of TextPredicate objects
|
|
165
|
+
"""
|
|
166
|
+
if semantic_match is None:
|
|
167
|
+
return []
|
|
168
|
+
|
|
169
|
+
if isinstance(semantic_match, str):
|
|
170
|
+
return [TextPredicate(query=semantic_match)]
|
|
171
|
+
|
|
172
|
+
if isinstance(semantic_match, list):
|
|
173
|
+
# Handle both list of strings and list of dicts
|
|
174
|
+
predicates = []
|
|
175
|
+
for item in semantic_match:
|
|
176
|
+
if isinstance(item, str):
|
|
177
|
+
predicates.append(TextPredicate(query=item))
|
|
178
|
+
elif isinstance(item, dict):
|
|
179
|
+
query = item.get("query", "")
|
|
180
|
+
threshold = item.get("threshold", 0.4)
|
|
181
|
+
field = item.get("field", None)
|
|
182
|
+
predicates.append(
|
|
183
|
+
TextPredicate(query=query, threshold=threshold, field=field)
|
|
184
|
+
)
|
|
185
|
+
return predicates
|
|
186
|
+
|
|
187
|
+
if isinstance(semantic_match, dict):
|
|
188
|
+
query = semantic_match.get("query", "")
|
|
189
|
+
threshold = semantic_match.get("threshold", 0.4) # Match dataclass default
|
|
190
|
+
field = semantic_match.get("field", None)
|
|
191
|
+
return [TextPredicate(query=query, threshold=threshold, field=field)]
|
|
192
|
+
|
|
193
|
+
return []
|
|
194
|
+
|
|
139
195
|
def accepts_direct(self) -> bool:
|
|
140
196
|
return self.mode in {"direct", "both"}
|
|
141
197
|
|
|
@@ -159,12 +215,99 @@ class Subscription:
|
|
|
159
215
|
return False
|
|
160
216
|
except Exception:
|
|
161
217
|
return False
|
|
218
|
+
|
|
219
|
+
# Evaluate text predicates using semantic matching
|
|
220
|
+
if self.text_predicates:
|
|
221
|
+
if not self._matches_text_predicates(artifact):
|
|
222
|
+
return False
|
|
223
|
+
|
|
162
224
|
return True
|
|
163
225
|
|
|
226
|
+
def _matches_text_predicates(self, artifact: Artifact) -> bool:
|
|
227
|
+
"""Check if artifact matches all text predicates (AND logic).
|
|
228
|
+
|
|
229
|
+
Args:
|
|
230
|
+
artifact: The artifact to check
|
|
231
|
+
|
|
232
|
+
Returns:
|
|
233
|
+
bool: True if all text predicates match (or if semantic unavailable)
|
|
234
|
+
"""
|
|
235
|
+
# Check if semantic features available
|
|
236
|
+
try:
|
|
237
|
+
from flock.semantic import SEMANTIC_AVAILABLE, EmbeddingService
|
|
238
|
+
except ImportError:
|
|
239
|
+
# Graceful degradation - if semantic not available, skip text predicates
|
|
240
|
+
return True
|
|
241
|
+
|
|
242
|
+
if not SEMANTIC_AVAILABLE:
|
|
243
|
+
# Graceful degradation
|
|
244
|
+
return True
|
|
245
|
+
|
|
246
|
+
try:
|
|
247
|
+
embedding_service = EmbeddingService.get_instance()
|
|
248
|
+
except Exception:
|
|
249
|
+
# If embedding service fails, degrade gracefully
|
|
250
|
+
return True
|
|
251
|
+
|
|
252
|
+
# Extract text from artifact payload
|
|
253
|
+
artifact_text = self._extract_text_from_payload(artifact.payload)
|
|
254
|
+
if not artifact_text or not artifact_text.strip():
|
|
255
|
+
# No text to match against
|
|
256
|
+
return False
|
|
257
|
+
|
|
258
|
+
# Check all predicates (AND logic)
|
|
259
|
+
for predicate in self.text_predicates:
|
|
260
|
+
try:
|
|
261
|
+
# Extract text based on field specification
|
|
262
|
+
if predicate.field:
|
|
263
|
+
# Use specific field
|
|
264
|
+
text_to_match = str(artifact.payload.get(predicate.field, ""))
|
|
265
|
+
else:
|
|
266
|
+
# Use all text from payload
|
|
267
|
+
text_to_match = artifact_text
|
|
268
|
+
|
|
269
|
+
if not text_to_match or not text_to_match.strip():
|
|
270
|
+
return False
|
|
271
|
+
|
|
272
|
+
# Compute semantic similarity
|
|
273
|
+
similarity = embedding_service.similarity(
|
|
274
|
+
predicate.query, text_to_match
|
|
275
|
+
)
|
|
276
|
+
|
|
277
|
+
# Check threshold
|
|
278
|
+
if similarity < predicate.threshold:
|
|
279
|
+
return False
|
|
280
|
+
|
|
281
|
+
except Exception:
|
|
282
|
+
# If any error occurs, fail the match
|
|
283
|
+
return False
|
|
284
|
+
|
|
285
|
+
return True
|
|
286
|
+
|
|
287
|
+
def _extract_text_from_payload(self, payload: dict[str, Any]) -> str:
|
|
288
|
+
"""Extract all text content from payload.
|
|
289
|
+
|
|
290
|
+
Args:
|
|
291
|
+
payload: The artifact payload dict
|
|
292
|
+
|
|
293
|
+
Returns:
|
|
294
|
+
str: Concatenated text from all string fields
|
|
295
|
+
"""
|
|
296
|
+
text_parts = []
|
|
297
|
+
for value in payload.values():
|
|
298
|
+
if isinstance(value, str):
|
|
299
|
+
text_parts.append(value)
|
|
300
|
+
elif isinstance(value, (list, tuple)):
|
|
301
|
+
for item in value:
|
|
302
|
+
if isinstance(item, str):
|
|
303
|
+
text_parts.append(item)
|
|
304
|
+
|
|
305
|
+
return " ".join(text_parts)
|
|
306
|
+
|
|
164
307
|
def __repr__(self) -> str: # pragma: no cover - debug helper
|
|
165
308
|
return (
|
|
166
309
|
f"Subscription(agent={self.agent_name!r}, types={list(self.type_names)!r}, "
|
|
167
|
-
f"
|
|
310
|
+
f"mode={self.mode!r})"
|
|
168
311
|
)
|
|
169
312
|
|
|
170
313
|
|
|
@@ -0,0 +1,49 @@
|
|
|
1
|
+
"""Semantic subscriptions for Flock.
|
|
2
|
+
|
|
3
|
+
This module provides semantic matching capabilities using sentence-transformers.
|
|
4
|
+
It's an optional feature that requires installing the [semantic] extra:
|
|
5
|
+
|
|
6
|
+
uv add flock-core[semantic]
|
|
7
|
+
|
|
8
|
+
If sentence-transformers is not installed, semantic features will gracefully
|
|
9
|
+
degrade and core Flock functionality remains unaffected.
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
# Try to import semantic features
|
|
13
|
+
try:
|
|
14
|
+
from sentence_transformers import SentenceTransformer # noqa: F401
|
|
15
|
+
|
|
16
|
+
from .context_provider import SemanticContextProvider
|
|
17
|
+
from .embedding_service import EmbeddingService
|
|
18
|
+
|
|
19
|
+
SEMANTIC_AVAILABLE = True
|
|
20
|
+
except ImportError as e:
|
|
21
|
+
SEMANTIC_AVAILABLE = False
|
|
22
|
+
_import_error = e
|
|
23
|
+
|
|
24
|
+
# Provide helpful error message when features are used
|
|
25
|
+
class EmbeddingService: # type: ignore
|
|
26
|
+
"""Placeholder when semantic extras not installed."""
|
|
27
|
+
|
|
28
|
+
@staticmethod
|
|
29
|
+
def get_instance(*args, **kwargs):
|
|
30
|
+
raise ImportError(
|
|
31
|
+
"Semantic features require sentence-transformers. "
|
|
32
|
+
"Install with: uv add flock-core[semantic]"
|
|
33
|
+
) from _import_error
|
|
34
|
+
|
|
35
|
+
class SemanticContextProvider: # type: ignore
|
|
36
|
+
"""Placeholder when semantic extras not installed."""
|
|
37
|
+
|
|
38
|
+
def __init__(self, *args, **kwargs):
|
|
39
|
+
raise ImportError(
|
|
40
|
+
"Semantic features require sentence-transformers. "
|
|
41
|
+
"Install with: uv add flock-core[semantic]"
|
|
42
|
+
) from _import_error
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
__all__ = [
|
|
46
|
+
"SEMANTIC_AVAILABLE",
|
|
47
|
+
"EmbeddingService",
|
|
48
|
+
"SemanticContextProvider",
|
|
49
|
+
]
|
|
@@ -0,0 +1,173 @@
|
|
|
1
|
+
"""Semantic context providers for agent execution.
|
|
2
|
+
|
|
3
|
+
This module provides context providers that use semantic similarity to find
|
|
4
|
+
relevant historical artifacts for agent context.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
from collections.abc import Callable
|
|
10
|
+
from typing import TYPE_CHECKING, Any
|
|
11
|
+
|
|
12
|
+
from pydantic import BaseModel
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
if TYPE_CHECKING:
|
|
16
|
+
from flock.core.artifacts import Artifact
|
|
17
|
+
from flock.core.store import ArtifactStore
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class SemanticContextProvider:
|
|
21
|
+
"""Context provider that retrieves semantically relevant historical artifacts.
|
|
22
|
+
|
|
23
|
+
This provider uses semantic similarity to find artifacts that are relevant
|
|
24
|
+
to a given query text, enabling agents to make decisions based on similar
|
|
25
|
+
past events.
|
|
26
|
+
|
|
27
|
+
Args:
|
|
28
|
+
query_text: The semantic query to match against artifacts
|
|
29
|
+
threshold: Minimum similarity score (0.0 to 1.0) to include in results
|
|
30
|
+
limit: Maximum number of artifacts to return
|
|
31
|
+
extract_field: Optional field name to extract from artifact payload for matching.
|
|
32
|
+
If None, uses all text from payload.
|
|
33
|
+
artifact_type: Optional type filter - only return artifacts of this type
|
|
34
|
+
where: Optional predicate filter for additional filtering
|
|
35
|
+
|
|
36
|
+
Example:
|
|
37
|
+
```python
|
|
38
|
+
provider = SemanticContextProvider(
|
|
39
|
+
query_text="user authentication issues", threshold=0.5, limit=5
|
|
40
|
+
)
|
|
41
|
+
|
|
42
|
+
relevant_artifacts = await provider.get_context(store)
|
|
43
|
+
```
|
|
44
|
+
"""
|
|
45
|
+
|
|
46
|
+
def __init__(
|
|
47
|
+
self,
|
|
48
|
+
query_text: str,
|
|
49
|
+
threshold: float = 0.4,
|
|
50
|
+
limit: int = 10,
|
|
51
|
+
extract_field: str | None = None,
|
|
52
|
+
artifact_type: type[BaseModel] | None = None,
|
|
53
|
+
where: Callable[[Artifact], bool] | None = None,
|
|
54
|
+
):
|
|
55
|
+
"""Initialize semantic context provider.
|
|
56
|
+
|
|
57
|
+
Args:
|
|
58
|
+
query_text: The semantic query text
|
|
59
|
+
threshold: Minimum similarity score (default: 0.4)
|
|
60
|
+
limit: Maximum results to return (default: 10)
|
|
61
|
+
extract_field: Optional field to extract from payload
|
|
62
|
+
artifact_type: Optional type filter
|
|
63
|
+
where: Optional predicate for additional filtering
|
|
64
|
+
"""
|
|
65
|
+
if not query_text or not query_text.strip():
|
|
66
|
+
raise ValueError("query_text cannot be empty")
|
|
67
|
+
|
|
68
|
+
if not 0.0 <= threshold <= 1.0:
|
|
69
|
+
raise ValueError("threshold must be between 0 and 1")
|
|
70
|
+
|
|
71
|
+
if limit < 1:
|
|
72
|
+
raise ValueError("limit must be at least 1")
|
|
73
|
+
|
|
74
|
+
self.query_text = query_text
|
|
75
|
+
self.threshold = threshold
|
|
76
|
+
self.limit = limit
|
|
77
|
+
self.extract_field = extract_field
|
|
78
|
+
self.artifact_type = artifact_type
|
|
79
|
+
self.where = where
|
|
80
|
+
|
|
81
|
+
async def get_context(self, store: ArtifactStore) -> list[Artifact]:
|
|
82
|
+
"""Retrieve semantically relevant artifacts from store.
|
|
83
|
+
|
|
84
|
+
Args:
|
|
85
|
+
store: The artifact store to query
|
|
86
|
+
|
|
87
|
+
Returns:
|
|
88
|
+
List of relevant artifacts, sorted by similarity (highest first)
|
|
89
|
+
"""
|
|
90
|
+
# Check if semantic features available
|
|
91
|
+
try:
|
|
92
|
+
from flock.semantic import SEMANTIC_AVAILABLE, EmbeddingService
|
|
93
|
+
except ImportError:
|
|
94
|
+
return []
|
|
95
|
+
|
|
96
|
+
if not SEMANTIC_AVAILABLE:
|
|
97
|
+
return []
|
|
98
|
+
|
|
99
|
+
try:
|
|
100
|
+
embedding_service = EmbeddingService.get_instance()
|
|
101
|
+
except Exception:
|
|
102
|
+
return []
|
|
103
|
+
|
|
104
|
+
# Get query embedding
|
|
105
|
+
try:
|
|
106
|
+
query_embedding = embedding_service.embed(self.query_text)
|
|
107
|
+
except Exception:
|
|
108
|
+
return []
|
|
109
|
+
|
|
110
|
+
# Get all artifacts from store
|
|
111
|
+
all_artifacts = await store.list()
|
|
112
|
+
|
|
113
|
+
# Filter by type if specified
|
|
114
|
+
if self.artifact_type:
|
|
115
|
+
from flock.registry import type_registry
|
|
116
|
+
|
|
117
|
+
type_name = type_registry.register(self.artifact_type)
|
|
118
|
+
all_artifacts = [a for a in all_artifacts if a.type == type_name]
|
|
119
|
+
|
|
120
|
+
# Filter by where clause if specified
|
|
121
|
+
if self.where:
|
|
122
|
+
all_artifacts = [a for a in all_artifacts if self.where(a)]
|
|
123
|
+
|
|
124
|
+
# Compute similarities and filter
|
|
125
|
+
results: list[tuple[Artifact, float]] = []
|
|
126
|
+
|
|
127
|
+
for artifact in all_artifacts:
|
|
128
|
+
try:
|
|
129
|
+
# Extract text from artifact
|
|
130
|
+
if self.extract_field:
|
|
131
|
+
# Use specific field
|
|
132
|
+
text = str(artifact.payload.get(self.extract_field, ""))
|
|
133
|
+
else:
|
|
134
|
+
# Use all text from payload
|
|
135
|
+
text = self._extract_text_from_payload(artifact.payload)
|
|
136
|
+
|
|
137
|
+
if not text or not text.strip():
|
|
138
|
+
continue
|
|
139
|
+
|
|
140
|
+
# Compute similarity
|
|
141
|
+
similarity = embedding_service.similarity(self.query_text, text)
|
|
142
|
+
|
|
143
|
+
# Check threshold
|
|
144
|
+
if similarity >= self.threshold:
|
|
145
|
+
results.append((artifact, similarity))
|
|
146
|
+
|
|
147
|
+
except Exception:
|
|
148
|
+
# Skip artifacts that fail processing
|
|
149
|
+
continue
|
|
150
|
+
|
|
151
|
+
# Sort by similarity (highest first) and take top N
|
|
152
|
+
results.sort(key=lambda x: x[1], reverse=True)
|
|
153
|
+
return [artifact for artifact, _ in results[: self.limit]]
|
|
154
|
+
|
|
155
|
+
def _extract_text_from_payload(self, payload: dict[str, Any]) -> str:
|
|
156
|
+
"""Extract all text content from payload.
|
|
157
|
+
|
|
158
|
+
Args:
|
|
159
|
+
payload: The artifact payload dict
|
|
160
|
+
|
|
161
|
+
Returns:
|
|
162
|
+
str: Concatenated text from all string fields
|
|
163
|
+
"""
|
|
164
|
+
text_parts = []
|
|
165
|
+
for value in payload.values():
|
|
166
|
+
if isinstance(value, str):
|
|
167
|
+
text_parts.append(value)
|
|
168
|
+
elif isinstance(value, (list, tuple)):
|
|
169
|
+
for item in value:
|
|
170
|
+
if isinstance(item, str):
|
|
171
|
+
text_parts.append(item)
|
|
172
|
+
|
|
173
|
+
return " ".join(text_parts)
|