isage-tooluse 0.1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- isage_tooluse-0.1.0.0.dist-info/METADATA +208 -0
- isage_tooluse-0.1.0.0.dist-info/RECORD +14 -0
- isage_tooluse-0.1.0.0.dist-info/WHEEL +5 -0
- isage_tooluse-0.1.0.0.dist-info/licenses/LICENSE +21 -0
- isage_tooluse-0.1.0.0.dist-info/top_level.txt +1 -0
- sage_libs/sage_tooluse/__init__.py +75 -0
- sage_libs/sage_tooluse/base.py +203 -0
- sage_libs/sage_tooluse/dfsdt_selector.py +402 -0
- sage_libs/sage_tooluse/embedding_selector.py +281 -0
- sage_libs/sage_tooluse/gorilla_selector.py +495 -0
- sage_libs/sage_tooluse/hybrid_selector.py +202 -0
- sage_libs/sage_tooluse/keyword_selector.py +270 -0
- sage_libs/sage_tooluse/registry.py +185 -0
- sage_libs/sage_tooluse/schemas.py +196 -0
|
@@ -0,0 +1,402 @@
|
|
|
1
|
+
"""
|
|
2
|
+
DFSDT (Depth-First Search-based Decision Tree) Tool Selector.
|
|
3
|
+
|
|
4
|
+
Implementation based on ToolLLM paper (Qin et al., 2023):
|
|
5
|
+
"ToolLLM: Facilitating Large Language Models to Master 16000+ Real-world APIs"
|
|
6
|
+
|
|
7
|
+
The DFSDT algorithm treats tool selection as a tree search problem where:
|
|
8
|
+
1. Each node represents a candidate tool evaluation state
|
|
9
|
+
2. LLM is used to score tool relevance at each node
|
|
10
|
+
3. DFS explores promising paths first based on scores
|
|
11
|
+
4. Diversity prompting encourages exploration of different tool combinations
|
|
12
|
+
"""
|
|
13
|
+
|
|
14
|
+
import logging
|
|
15
|
+
from dataclasses import dataclass, field
|
|
16
|
+
from typing import Optional
|
|
17
|
+
|
|
18
|
+
from .base import BaseToolSelector, SelectorResources
|
|
19
|
+
from .keyword_selector import KeywordSelector
|
|
20
|
+
from .schemas import (
|
|
21
|
+
DFSDTSelectorConfig,
|
|
22
|
+
KeywordSelectorConfig,
|
|
23
|
+
SelectorConfig,
|
|
24
|
+
ToolPrediction,
|
|
25
|
+
ToolSelectionQuery,
|
|
26
|
+
)
|
|
27
|
+
|
|
28
|
+
# Module-level logger; DFSDTSelector itself logs through ``self.logger``.
logger = logging.getLogger(__name__)


# Prompt templates for LLM-based tool scoring.
#
# TOOL_RELEVANCE_PROMPT is filled with str.format() using the placeholders
# {query}, {tool_name}, {tool_description} and {tool_capabilities}. It asks the
# model for a single 0-10 rating; DFSDTSelector parses the first number found in
# the reply and divides it by 10 to obtain a score in [0, 1].
TOOL_RELEVANCE_PROMPT = """You are a tool selection expert. Given a user query and a candidate tool,
evaluate how relevant the tool is for completing the query.

User Query: {query}

Candidate Tool:
- Name: {tool_name}
- Description: {tool_description}
- Capabilities: {tool_capabilities}

Rate the relevance of this tool for the given query on a scale of 0 to 10, where:
- 0-2: Not relevant at all
- 3-4: Slightly relevant, might be useful indirectly
- 5-6: Moderately relevant, could help with part of the task
- 7-8: Highly relevant, directly addresses the query
- 9-10: Perfect match, exactly what's needed

Provide your rating as a single number. Only output the number, nothing else.

Rating:"""

# DIVERSITY_PROMPT wraps an already-formatted TOOL_RELEVANCE_PROMPT (passed in as
# {base_prompt}). It is used only when diversity prompting is enabled and at
# least one earlier tool scored at or above the score threshold; {previous_tools}
# receives the comma-joined names of (at most) the last three such tools.
DIVERSITY_PROMPT = """This is not the first time evaluating tools for this query.
Previous highly-rated tools were: {previous_tools}

Now evaluate a different tool that might provide a complementary or alternative approach.
Consider tools that:
1. Address different aspects of the query
2. Provide backup options if primary tools fail
3. Offer different methods to achieve the same goal

{base_prompt}"""
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
@dataclass
class SearchNode:
    """Node in the DFSDT search tree.

    Identity (``==`` and ``hash``) is based solely on ``tool_id``, so two nodes
    for the same tool compare equal regardless of score, depth or links.
    """

    tool_id: str
    tool_name: str
    tool_description: str
    score: float = 0.0
    depth: int = 0
    # Tree links are excluded from repr(): otherwise printing a single node
    # would dump its parent and every sibling/descendant.
    parent: Optional["SearchNode"] = field(default=None, repr=False)
    children: list["SearchNode"] = field(default_factory=list, repr=False)
    visited: bool = False
    pruned: bool = False

    def __hash__(self):
        return hash(self.tool_id)

    def __eq__(self, other):
        if isinstance(other, SearchNode):
            return self.tool_id == other.tool_id
        # Let Python try the reflected comparison / fall back to identity
        # instead of claiming inequality for types we don't understand.
        return NotImplemented
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
class DFSDTSelector(BaseToolSelector):
    """
    DFSDT (Depth-First Search-based Decision Tree) tool selector.

    This selector implements the core idea from ToolLLM:
    1. Pre-filter candidates using fast keyword matching (optional)
    2. Build a search tree with candidate tools as nodes
    3. Use LLM to score each tool's relevance to the query
    4. DFS traversal prioritizes high-scoring branches
    5. Diversity prompting explores alternative tool combinations

    Key Features:
    - LLM-guided scoring for semantic understanding
    - Tree search for exploring multiple tool combinations
    - Diversity mechanism to avoid local optima
    - Keyword pre-filtering for efficiency with large tool sets

    Implementation note: the search tree currently has a single level below
    the root. Every candidate is scored once (by the LLM, or by keyword
    overlap when no LLM client is available), candidates scoring below
    ``config.score_threshold`` are dropped, and the rest are ranked by score.
    Candidate order is deterministic so that diversity-prompt context and
    tie-breaking are reproducible across runs.
    """

    def __init__(self, config: DFSDTSelectorConfig, resources: SelectorResources):
        """
        Initialize DFSDT selector.

        Args:
            config: DFSDT selector configuration
            resources: Shared resources; ``resources.tools_loader`` is read to
                preload tool metadata.
        """
        super().__init__(config, resources)
        self.config: DFSDTSelectorConfig = config

        # LLM client for scoring; created lazily on first use.
        self._llm_client = None
        self._llm_initialized = False

        # Optional BM25 keyword selector that shrinks large candidate pools
        # before the (expensive) per-tool LLM scoring.
        self._keyword_selector: Optional[KeywordSelector] = None
        if config.use_keyword_prefilter:
            keyword_config = KeywordSelectorConfig(
                name="keyword_prefilter",
                method="bm25",
                top_k=config.prefilter_k,
            )
            self._keyword_selector = KeywordSelector(keyword_config, resources)

        # tool_id -> {"name", "description", "capabilities", "category"}
        self._tool_cache: dict[str, dict] = {}
        self._preload_tools()

    @classmethod
    def from_config(cls, config: SelectorConfig, resources: SelectorResources) -> "DFSDTSelector":
        """Create DFSDT selector from config.

        A generic ``SelectorConfig`` is upgraded to a ``DFSDTSelectorConfig``
        carrying over its common fields.
        """
        if not isinstance(config, DFSDTSelectorConfig):
            config = DFSDTSelectorConfig(
                name=config.name,
                top_k=config.top_k,
                min_score=config.min_score,
                cache_enabled=config.cache_enabled,
                params=config.params,
            )
        return cls(config, resources)

    @staticmethod
    def _normalize_capabilities(capabilities) -> list[str]:
        """Coerce a tool's ``capabilities`` value into a list of strings.

        A bare string is treated as one capability (joining it directly would
        put ", " between its characters); other iterables have each element
        stringified; falsy values give an empty list.
        """
        if not capabilities:
            return []
        if isinstance(capabilities, str):
            return [capabilities]
        try:
            return [str(cap) for cap in capabilities]
        except TypeError:
            # Not iterable (a single object/enum): keep it as one entry.
            return [str(capabilities)]

    @staticmethod
    def _tokenize(text: str) -> set[str]:
        """Lower-case ``text`` and split it into a set of word tokens.

        Splits on any non-word character and on underscores, so punctuation
        ("weather?") and snake_case names ("get_weather") still match plain
        query words.
        """
        import re

        return {token for token in re.split(r"[\W_]+", text.lower()) if token}

    def _preload_tools(self) -> None:
        """Preload tool metadata into ``self._tool_cache``.

        Best-effort: an error while iterating the loader is logged, and the
        cache keeps whatever entries were added before the failure.
        """
        try:
            tools_loader = self.resources.tools_loader
            for tool in tools_loader.iter_all():
                self._tool_cache[tool.tool_id] = {
                    "name": tool.name,
                    "description": getattr(tool, "description", "") or "",
                    "capabilities": self._normalize_capabilities(
                        getattr(tool, "capabilities", None)
                    ),
                    "category": getattr(tool, "category", "") or "",
                }
            self.logger.info(f"DFSDT: Preloaded {len(self._tool_cache)} tools")
        except Exception as e:
            self.logger.error(f"Error preloading tools: {e}")

    def _get_llm_client(self):
        """Return the LLM client, creating it on first call.

        Initialization is attempted only once: if it fails, the failure is
        logged and ``None`` is returned on this and every later call, which
        makes scoring fall back to keyword overlap.
        """
        if not self._llm_initialized:
            try:
                from sage.llm import UnifiedInferenceClient

                self._llm_client = UnifiedInferenceClient.create()
                self._llm_initialized = True
                self.logger.info("DFSDT: LLM client initialized")
            except Exception as e:
                self.logger.warning(f"Could not initialize LLM client: {e}")
                self._llm_client = None
                self._llm_initialized = True

        return self._llm_client

    def _select_impl(self, query: ToolSelectionQuery, top_k: int) -> list[ToolPrediction]:
        """
        Select tools using DFSDT algorithm.

        Args:
            query: Tool selection query; ``query.candidate_tools`` (if given)
                restricts the pool, otherwise every preloaded tool is a candidate.
            top_k: Number of tools to select

        Returns:
            List of tool predictions with scores, highest first
        """
        # Step 1: candidate pool as an order-preserving, de-duplicated list
        # (a set would make iteration order hash-dependent between runs).
        if query.candidate_tools:
            candidate_ids = list(dict.fromkeys(query.candidate_tools))
        else:
            candidate_ids = list(self._tool_cache)

        # Step 2: pre-filter using keyword matching if enabled.
        if self._keyword_selector and len(candidate_ids) > self.config.prefilter_k:
            allowed = set(candidate_ids)
            prefilter_results = self._keyword_selector._select_impl(query, self.config.prefilter_k)
            # Keep the keyword ranking order, and never admit tools that were
            # outside the caller's candidate pool.
            candidate_ids = list(
                dict.fromkeys(p.tool_id for p in prefilter_results if p.tool_id in allowed)
            )
            self.logger.debug(f"DFSDT: Pre-filtered to {len(candidate_ids)} candidates")

        # Step 3: build search tree and run DFSDT.
        return self._dfsdt_search(query, candidate_ids, top_k)

    def _dfsdt_search(
        self,
        query: ToolSelectionQuery,
        candidate_ids: "list[str] | set[str]",
        top_k: int,
    ) -> list[ToolPrediction]:
        """
        Run DFSDT search algorithm.

        Args:
            query: Tool selection query
            candidate_ids: Candidate tool IDs, scored in iteration order; IDs
                missing from the tool cache and repeated IDs are skipped.
            top_k: Number of tools to select (``<= 0`` returns an empty list)

        Returns:
            List of tool predictions sorted by score (descending; ties keep
            candidate order because the sort is stable)
        """
        if top_k <= 0:
            return []

        root = SearchNode(
            tool_id="__root__",
            tool_name="root",
            tool_description="Search tree root",
            depth=0,
        )

        # Create one child node per known candidate.
        seen: set[str] = set()
        for tool_id in candidate_ids:
            tool_info = self._tool_cache.get(tool_id)
            if tool_info is None or tool_id in seen:
                continue
            seen.add(tool_id)
            root.children.append(
                SearchNode(
                    tool_id=tool_id,
                    tool_name=tool_info["name"],
                    tool_description=tool_info["description"],
                    depth=1,
                    parent=root,
                )
            )

        # Score every child; nodes at/above the threshold are kept and their
        # names feed the diversity prompt for the nodes scored after them.
        scored_nodes: list[SearchNode] = []
        visited_tools: list[str] = []
        for node in root.children:
            node.score = self._score_tool(query, node, visited_tools)
            node.visited = True
            if node.score >= self.config.score_threshold:
                scored_nodes.append(node)
                visited_tools.append(node.tool_name)

        scored_nodes.sort(key=lambda n: n.score, reverse=True)

        return [
            ToolPrediction(
                tool_id=node.tool_id,
                score=max(0.0, min(node.score, 1.0)),  # keep within [0, 1]
                metadata={
                    "method": "dfsdt",
                    "tool_name": node.tool_name,
                    "depth": node.depth,
                },
            )
            for node in scored_nodes[:top_k]
        ]

    def _score_tool(
        self,
        query: ToolSelectionQuery,
        node: SearchNode,
        visited_tools: list[str],
    ) -> float:
        """
        Score a tool's relevance using LLM.

        Falls back to keyword overlap when no LLM client is available or when
        the LLM call / response parsing raises.

        Args:
            query: Tool selection query
            node: Search node representing the tool
            visited_tools: Names of tools already rated at/above the threshold
                (the last three are shown in the diversity prompt)

        Returns:
            Relevance score (0-1)
        """
        llm_client = self._get_llm_client()
        if llm_client is None:
            return self._fallback_score(query, node)

        try:
            tool_info = self._tool_cache.get(node.tool_id, {})
            capabilities = self._normalize_capabilities(tool_info.get("capabilities"))
            cap_str = ", ".join(capabilities) if capabilities else "N/A"

            base_prompt = TOOL_RELEVANCE_PROMPT.format(
                query=query.instruction,
                tool_name=node.tool_name,
                tool_description=node.tool_description or "No description",
                tool_capabilities=cap_str,
            )

            if self.config.use_diversity_prompt and visited_tools:
                prompt = DIVERSITY_PROMPT.format(
                    previous_tools=", ".join(visited_tools[-3:]),  # Last 3 tools
                    base_prompt=base_prompt,
                )
            else:
                prompt = base_prompt

            response = llm_client.chat(
                [{"role": "user", "content": prompt}],
                temperature=self.config.temperature,
            )

            # The prompt asks for a 0-10 rating; normalize to 0-1.
            return self._parse_score(response) / 10.0

        except Exception as e:
            self.logger.warning(f"LLM scoring failed for {node.tool_id}: {e}")
            return self._fallback_score(query, node)

    def _fallback_score(self, query: ToolSelectionQuery, node: SearchNode) -> float:
        """
        Fallback scoring using keyword matching when LLM is unavailable.

        The score is the fraction of distinct query tokens that also appear in
        the tool's name or description.

        Args:
            query: Tool selection query
            node: Search node representing the tool

        Returns:
            Relevance score (0-1); 0.0 if either side has no tokens
        """
        query_words = self._tokenize(query.instruction)
        tool_words = self._tokenize(f"{node.tool_name} {node.tool_description}")

        if not query_words or not tool_words:
            return 0.0

        overlap = len(query_words & tool_words)
        return min(overlap / len(query_words), 1.0)

    def _parse_score(self, response: str) -> float:
        """
        Parse numeric score from LLM response.

        The first number in the response is used, clamped to [0, 10]. If the
        response contains no number, a coarse score is inferred from wording
        (first matching tier wins), defaulting to 1.0.

        Args:
            response: LLM response string

        Returns:
            Parsed score (0-10)
        """
        import re

        numbers = re.findall(r"(\d+(?:\.\d+)?)", response.strip())
        if numbers:
            score = float(numbers[0])
            return min(max(score, 0.0), 10.0)  # Clamp to 0-10

        response_lower = response.lower()
        if any(word in response_lower for word in ["perfect", "excellent", "exactly"]):
            return 9.0
        elif any(word in response_lower for word in ["highly", "very relevant", "good"]):
            return 7.0
        elif any(word in response_lower for word in ["moderate", "somewhat", "partial"]):
            return 5.0
        elif any(word in response_lower for word in ["slight", "minimal", "limited"]):
            return 3.0
        else:
            return 1.0

    def get_stats(self) -> dict:
        """Get selector statistics (base stats plus DFSDT-specific fields)."""
        stats = super().get_stats()
        stats.update(
            {
                "llm_initialized": self._llm_initialized,
                "tool_cache_size": len(self._tool_cache),
                "has_keyword_prefilter": self._keyword_selector is not None,
            }
        )
        return stats
|
|
@@ -0,0 +1,281 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Embedding-based tool selector.
|
|
3
|
+
|
|
4
|
+
Uses embedding models and vector similarity search for tool selection.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import logging
|
|
8
|
+
from typing import Optional
|
|
9
|
+
|
|
10
|
+
import numpy as np
|
|
11
|
+
|
|
12
|
+
from .base import BaseToolSelector, SelectorResources
|
|
13
|
+
from .retriever.vector_index import VectorIndex
|
|
14
|
+
from .schemas import (
|
|
15
|
+
EmbeddingSelectorConfig,
|
|
16
|
+
SelectorConfig,
|
|
17
|
+
ToolPrediction,
|
|
18
|
+
ToolSelectionQuery,
|
|
19
|
+
)
|
|
20
|
+
|
|
21
|
+
# Module-level logger. EmbeddingSelector itself logs through ``self.logger``
# (presumably provided by BaseToolSelector — not visible here).
logger = logging.getLogger(__name__)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class EmbeddingSelector(BaseToolSelector):
    """
    Embedding-based tool selector using vector similarity.

    Uses embedding service to encode queries and tools, then performs
    similarity search using VectorIndex.

    Optimized for large-scale tool selection (1000+ tools).

    Every tool from ``resources.tools_loader`` is embedded once, at
    construction time; each query then costs one embedding call plus one
    index search.
    """

    def __init__(self, config: EmbeddingSelectorConfig, resources: SelectorResources):
        """
        Initialize embedding selector and build its vector index.

        Args:
            config: Embedding selector configuration
            resources: Shared resources including embedding_client and tools_loader

        Raises:
            ValueError: If embedding_client is not provided in resources, or if
                the client returns a different number of vectors than tools.
            Exception: Other errors raised while embedding/indexing the tools
                propagate (after being logged).
        """
        super().__init__(config, resources)
        self.config: EmbeddingSelectorConfig = config

        # Validate embedding client
        if not resources.embedding_client:
            raise ValueError(
                "EmbeddingSelector requires embedding_client in SelectorResources. "
                "Please provide an EmbeddingService instance."
            )

        self.embedding_client = resources.embedding_client

        # Filled in by _preprocess_tools(); the index stays None when the
        # loader yields no tools.
        self._index: Optional[VectorIndex] = None
        self._tool_texts: dict[str, str] = {}
        self._embedding_dimension: Optional[int] = None

        self._preprocess_tools()

    @classmethod
    def from_config(
        cls, config: SelectorConfig, resources: SelectorResources
    ) -> "EmbeddingSelector":
        """Create embedding selector from config.

        Raises:
            TypeError: If ``config`` is not an ``EmbeddingSelectorConfig``.
        """
        if not isinstance(config, EmbeddingSelectorConfig):
            raise TypeError(f"Expected EmbeddingSelectorConfig, got {type(config).__name__}")
        return cls(config, resources)  # type: ignore[arg-type]

    def _preprocess_tools(self) -> None:
        """Embed all tools and build the vector index.

        Raises:
            ValueError: If the embedding client returns a different number of
                vectors than texts submitted.
            Exception: Any loader/embedding/index error is logged and re-raised.
        """
        try:
            tools_loader = self.resources.tools_loader

            tool_ids: list[str] = []
            tool_texts: list[str] = []
            for tool in tools_loader.iter_all():
                text = self._build_tool_text(tool)
                self._tool_texts[tool.tool_id] = text
                tool_ids.append(tool.tool_id)
                tool_texts.append(text)

            if not tool_texts:
                self.logger.warning("No tools found to preprocess")
                return

            self.logger.info(f"Embedding {len(tool_texts)} tools...")

            embeddings = self._embed_texts(tool_texts)

            # A row-count mismatch would pair tool IDs with the wrong vectors
            # (or fail inside the index with an unclear error); fail clearly.
            if embeddings.shape[0] != len(tool_ids):
                raise ValueError(
                    f"Embedding client returned {embeddings.shape[0]} vectors "
                    f"for {len(tool_ids)} tools"
                )

            self._embedding_dimension = embeddings.shape[1]

            self._index = VectorIndex(
                dimension=self._embedding_dimension, metric=self.config.similarity_metric
            )
            self._index.add_batch(
                vector_ids=tool_ids,
                vectors=embeddings,
                metadata=[{"text": text} for text in tool_texts],
            )

            self.logger.info(
                f"Built vector index with {len(tool_ids)} tools "
                f"(dimension={self._embedding_dimension}, metric={self.config.similarity_metric})"
            )

        except Exception as e:
            self.logger.error(f"Error preprocessing tools: {e}")
            raise

    def _build_tool_text(self, tool) -> str:
        """
        Build searchable text from tool metadata.

        Concatenates (space-separated) the name, description, capabilities,
        category and any ``"name: description"`` pairs from a dict of
        parameters. Every piece is converted with ``str()`` so non-string
        metadata cannot break the join.

        Args:
            tool: Tool object with metadata

        Returns:
            Concatenated text representation
        """
        parts = [str(tool.name)]

        description = getattr(tool, "description", None)
        if description:
            parts.append(str(description))

        capabilities = getattr(tool, "capabilities", None)
        if capabilities:
            if isinstance(capabilities, list):
                parts.extend(str(cap) for cap in capabilities)
            else:
                parts.append(str(capabilities))

        category = getattr(tool, "category", None)
        if category:
            parts.append(str(category))

        # Include parameter descriptions if available
        parameters = getattr(tool, "parameters", None)
        if parameters and isinstance(parameters, dict):
            for param_name, param_info in parameters.items():
                if isinstance(param_info, dict) and "description" in param_info:
                    parts.append(f"{param_name}: {param_info['description']}")

        return " ".join(parts)

    def _embed_texts(self, texts: list[str]) -> np.ndarray:
        """
        Embed texts using embedding client.

        Args:
            texts: List of texts to embed

        Returns:
            Array of embeddings (shape: N x D); a 1-D client result is
            reshaped to a single row.
        """
        try:
            # A configured model of "default" is sent as None, presumably
            # letting the client pick its own default model.
            embeddings = self.embedding_client.embed(
                texts=texts,
                model=self.config.embedding_model
                if self.config.embedding_model != "default"
                else None,
                batch_size=self.config.batch_size,
            )

            embeddings = np.asarray(embeddings)
            if embeddings.ndim == 1:
                embeddings = embeddings.reshape(1, -1)

            return embeddings

        except Exception as e:
            self.logger.error(f"Error embedding texts: {e}")
            raise

    def _select_impl(self, query: ToolSelectionQuery, top_k: int) -> list[ToolPrediction]:
        """
        Select tools using embedding similarity.

        Scores are mapped into [0, 1]: for ``euclidean`` the (negative)
        distance ``d`` becomes ``1 / (1 + |d|)``; other metrics are clamped
        to [0, 1]. If the config has a non-None ``similarity_threshold``,
        predictions below it are dropped.

        Args:
            query: Tool selection query
            top_k: Number of tools to select (``<= 0`` returns an empty list)

        Returns:
            List of tool predictions sorted by similarity
        """
        if self._index is None:
            self.logger.error("Vector index not initialized")
            return []
        if top_k <= 0:
            return []

        try:
            query_vector = self._embed_texts([query.instruction])[0]

            # Restrict the search to requested candidates that are indexed.
            candidate_ids = None
            if query.candidate_tools:
                candidate_ids = set(query.candidate_tools) & set(self._index._ids)
                if not candidate_ids:
                    self.logger.warning(f"No valid candidates for query {query.sample_id}")
                    return []

            # Returns list of (vector_id, score) tuples
            results = self._index.search(
                query_vector=query_vector,
                top_k=top_k,
                filter_ids=list(candidate_ids) if candidate_ids else None,
            )

            predictions = []
            for tool_id, raw_score in results:
                if self.config.similarity_metric == "euclidean":
                    # VectorIndex returns negative distances for euclidean.
                    score = 1.0 / (1.0 + abs(float(raw_score)))
                else:
                    # Cosine and dot are already similarities.
                    score = max(0.0, min(1.0, float(raw_score)))

                predictions.append(
                    ToolPrediction(
                        tool_id=tool_id,
                        score=score,
                        metadata={
                            "similarity_metric": self.config.similarity_metric,
                            "embedding_model": self.config.embedding_model,
                        },
                    )
                )

            # An unset (None) threshold means "no filtering"; comparing
            # against None would raise TypeError.
            threshold = getattr(self.config, "similarity_threshold", None)
            if threshold is not None:
                predictions = [p for p in predictions if p.score >= threshold]

            return predictions

        except Exception as e:
            self.logger.error(f"Error in embedding selection for {query.sample_id}: {e}")
            raise

    def get_embedding_dimension(self) -> Optional[int]:
        """Get embedding dimension (None until the index has been built)."""
        return self._embedding_dimension

    def get_index_size(self) -> int:
        """Get number of tools in index (0 if no index was built)."""
        return len(self._index._ids) if self._index else 0

    def get_stats(self) -> dict:
        """Get selector statistics (base stats plus embedding-specific fields)."""
        stats = super().get_stats()
        stats.update(
            {
                "embedding_dimension": self._embedding_dimension,
                "index_size": self.get_index_size(),
                "similarity_metric": self.config.similarity_metric,
                "embedding_model": self.config.embedding_model,
            }
        )
        return stats
|