isage-tooluse 0.1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,402 @@
1
+ """
2
+ DFSDT (Depth-First Search-based Decision Tree) Tool Selector.
3
+
4
+ Implementation based on ToolLLM paper (Qin et al., 2023):
5
+ "ToolLLM: Facilitating Large Language Models to Master 16000+ Real-world APIs"
6
+
7
+ The DFSDT algorithm treats tool selection as a tree search problem where:
8
+ 1. Each node represents a candidate tool evaluation state
9
+ 2. LLM is used to score tool relevance at each node
10
+ 3. DFS explores promising paths first based on scores
11
+ 4. Diversity prompting encourages exploration of different tool combinations
12
+ """
13
+
14
+ import logging
15
+ from dataclasses import dataclass, field
16
+ from typing import Optional
17
+
18
+ from .base import BaseToolSelector, SelectorResources
19
+ from .keyword_selector import KeywordSelector
20
+ from .schemas import (
21
+ DFSDTSelectorConfig,
22
+ KeywordSelectorConfig,
23
+ SelectorConfig,
24
+ ToolPrediction,
25
+ ToolSelectionQuery,
26
+ )
27
+
28
+ logger = logging.getLogger(__name__)
29
+
30
+
31
# Prompt templates for LLM-based tool scoring.
# TOOL_RELEVANCE_PROMPT asks the model for a single integer rating on a
# 0-10 scale; the reply is parsed by extracting the first number found.
TOOL_RELEVANCE_PROMPT = """You are a tool selection expert. Given a user query and a candidate tool,
evaluate how relevant the tool is for completing the query.

User Query: {query}

Candidate Tool:
- Name: {tool_name}
- Description: {tool_description}
- Capabilities: {tool_capabilities}

Rate the relevance of this tool for the given query on a scale of 0 to 10, where:
- 0-2: Not relevant at all
- 3-4: Slightly relevant, might be useful indirectly
- 5-6: Moderately relevant, could help with part of the task
- 7-8: Highly relevant, directly addresses the query
- 9-10: Perfect match, exactly what's needed

Provide your rating as a single number. Only output the number, nothing else.

Rating:"""

# Wrapper applied around the relevance prompt (via {base_prompt}) once some
# tools have already been rated highly for this query, to nudge the model
# toward complementary/alternative tools instead of repeated near-duplicates.
DIVERSITY_PROMPT = """This is not the first time evaluating tools for this query.
Previous highly-rated tools were: {previous_tools}

Now evaluate a different tool that might provide a complementary or alternative approach.
Consider tools that:
1. Address different aspects of the query
2. Provide backup options if primary tools fail
3. Offer different methods to achieve the same goal

{base_prompt}"""
63
+
64
+
65
@dataclass
class SearchNode:
    """A single candidate-tool node in the DFSDT search tree.

    Identity (equality and hashing) is defined solely by ``tool_id``, so two
    nodes referring to the same tool compare equal regardless of their score,
    depth, or position in the tree.
    """

    tool_id: str
    tool_name: str
    tool_description: str
    score: float = 0.0
    depth: int = 0
    parent: Optional["SearchNode"] = None
    children: list["SearchNode"] = field(default_factory=list)
    visited: bool = False
    pruned: bool = False

    def __hash__(self):
        # Hash must agree with __eq__: tool identity only.
        return hash(self.tool_id)

    def __eq__(self, other):
        # Non-SearchNode operands are never equal; otherwise compare by id.
        return isinstance(other, SearchNode) and self.tool_id == other.tool_id
86
+
87
+
88
class DFSDTSelector(BaseToolSelector):
    """
    DFSDT (Depth-First Search-based Decision Tree) tool selector.

    This selector implements the core idea from ToolLLM:
    1. Pre-filter candidates using fast keyword matching (optional)
    2. Build a search tree with candidate tools as nodes
    3. Use LLM to score each tool's relevance to the query
    4. DFS traversal prioritizes high-scoring branches
    5. Diversity prompting explores alternative tool combinations

    Key Features:
    - LLM-guided scoring for semantic understanding
    - Tree search for exploring multiple tool combinations
    - Diversity mechanism to avoid local optima
    - Keyword pre-filtering for efficiency with large tool sets
    """

    def __init__(self, config: DFSDTSelectorConfig, resources: SelectorResources):
        """
        Initialize DFSDT selector.

        Args:
            config: DFSDT selector configuration
            resources: Shared resources including tools loader
        """
        super().__init__(config, resources)
        self.config: DFSDTSelectorConfig = config

        # LLM client for scoring; created lazily in _get_llm_client() so
        # construction never blocks on (or fails because of) LLM availability.
        self._llm_client = None
        self._llm_initialized = False

        # Optional keyword (BM25) selector used to cheaply narrow the
        # candidate set before the per-tool LLM scoring pass.
        self._keyword_selector: Optional[KeywordSelector] = None
        if config.use_keyword_prefilter:
            keyword_config = KeywordSelectorConfig(
                name="keyword_prefilter",
                method="bm25",
                top_k=config.prefilter_k,
            )
            self._keyword_selector = KeywordSelector(keyword_config, resources)

        # Cache for tool metadata: tool_id -> {name, description,
        # capabilities, category}, populated once at construction time.
        self._tool_cache: dict[str, dict] = {}
        self._preload_tools()

    @classmethod
    def from_config(cls, config: SelectorConfig, resources: SelectorResources) -> "DFSDTSelector":
        """Create DFSDT selector from config.

        A generic SelectorConfig is converted to a DFSDTSelectorConfig by
        copying the shared fields; DFSDT-specific fields keep their defaults
        (any extras may still travel inside ``params``).
        """
        if not isinstance(config, DFSDTSelectorConfig):
            config = DFSDTSelectorConfig(
                name=config.name,
                top_k=config.top_k,
                min_score=config.min_score,
                cache_enabled=config.cache_enabled,
                params=config.params,
            )
        return cls(config, resources)

    def _preload_tools(self) -> None:
        """Preload tool metadata into the cache.

        Missing optional attributes (description/capabilities/category) are
        normalized to empty values so downstream code needs no None checks.
        Errors are logged but not raised: the selector then runs with an
        empty (or partial) cache.
        """
        try:
            tools_loader = self.resources.tools_loader
            for tool in tools_loader.iter_all():
                self._tool_cache[tool.tool_id] = {
                    "name": tool.name,
                    "description": getattr(tool, "description", "") or "",
                    "capabilities": getattr(tool, "capabilities", []) or [],
                    "category": getattr(tool, "category", "") or "",
                }
            self.logger.info(f"DFSDT: Preloaded {len(self._tool_cache)} tools")
        except Exception as e:
            self.logger.error(f"Error preloading tools: {e}")

    def _get_llm_client(self):
        """Lazily initialize and return the LLM client (or None).

        _llm_initialized is set True even when creation fails so the failing
        import/creation is attempted only once; subsequent scoring calls then
        fall back to keyword matching without retrying.
        """
        if not self._llm_initialized:
            try:
                from sage.llm import UnifiedInferenceClient

                self._llm_client = UnifiedInferenceClient.create()
                self._llm_initialized = True
                self.logger.info("DFSDT: LLM client initialized")
            except Exception as e:
                self.logger.warning(f"Could not initialize LLM client: {e}")
                self._llm_client = None
                self._llm_initialized = True

        return self._llm_client

    def _select_impl(self, query: ToolSelectionQuery, top_k: int) -> list[ToolPrediction]:
        """
        Select tools using DFSDT algorithm.

        Args:
            query: Tool selection query
            top_k: Number of tools to select

        Returns:
            List of tool predictions with scores
        """
        # Step 1: Get candidate tools — explicit candidates from the query,
        # otherwise every tool in the preloaded cache.
        if query.candidate_tools:
            candidate_ids = set(query.candidate_tools)
        else:
            candidate_ids = set(self._tool_cache.keys())

        # Step 2: Pre-filter using keyword matching if enabled; only fires
        # when there are more candidates than prefilter_k.
        # NOTE(review): this calls the keyword selector's private
        # _select_impl rather than its public entry point — presumably to
        # bypass wrapper bookkeeping; confirm against BaseToolSelector.
        if self._keyword_selector and len(candidate_ids) > self.config.prefilter_k:
            prefilter_results = self._keyword_selector._select_impl(query, self.config.prefilter_k)
            candidate_ids = {p.tool_id for p in prefilter_results}
            self.logger.debug(f"DFSDT: Pre-filtered to {len(candidate_ids)} candidates")

        # Step 3: Build search tree and run DFSDT
        results = self._dfsdt_search(query, candidate_ids, top_k)

        return results

    def _dfsdt_search(
        self, query: ToolSelectionQuery, candidate_ids: set[str], top_k: int
    ) -> list[ToolPrediction]:
        """
        Run DFSDT search algorithm.

        The current implementation builds a depth-1 tree (a synthetic root
        with one child per candidate), scores every child, keeps those at or
        above ``score_threshold``, and returns the top_k by score. Fewer than
        top_k predictions are returned when not enough tools clear the
        threshold.

        Args:
            query: Tool selection query
            candidate_ids: Set of candidate tool IDs
            top_k: Number of tools to select

        Returns:
            List of tool predictions sorted by score
        """
        # Create root node (synthetic; never scored or returned)
        root = SearchNode(
            tool_id="__root__",
            tool_name="root",
            tool_description="Search tree root",
            depth=0,
        )

        # Create child nodes for each candidate; IDs missing from the cache
        # (e.g. unknown candidates supplied by the query) are skipped.
        for tool_id in candidate_ids:
            if tool_id not in self._tool_cache:
                continue

            tool_info = self._tool_cache[tool_id]
            node = SearchNode(
                tool_id=tool_id,
                tool_name=tool_info["name"],
                tool_description=tool_info["description"],
                depth=1,
                parent=root,
            )
            root.children.append(node)

        # Score all nodes using LLM or fallback
        scored_nodes: list[SearchNode] = []
        visited_tools: list[str] = []

        for node in root.children:
            score = self._score_tool(query, node, visited_tools)
            node.score = score

            # Only tools above the threshold are kept — and only those feed
            # the diversity prompt of subsequent scoring calls.
            if score >= self.config.score_threshold:
                scored_nodes.append(node)
                visited_tools.append(node.tool_name)

        # Sort by score (DFS prioritizes high scores)
        scored_nodes.sort(key=lambda n: n.score, reverse=True)

        # Take top-k results
        predictions = []
        for node in scored_nodes[:top_k]:
            predictions.append(
                ToolPrediction(
                    tool_id=node.tool_id,
                    score=min(node.score, 1.0),  # Normalize to [0, 1]
                    metadata={
                        "method": "dfsdt",
                        "tool_name": node.tool_name,
                        "depth": node.depth,
                    },
                )
            )

        return predictions

    def _score_tool(
        self,
        query: ToolSelectionQuery,
        node: SearchNode,
        visited_tools: list[str],
    ) -> float:
        """
        Score a tool's relevance using the LLM.

        Falls back to keyword-overlap scoring when the LLM client is
        unavailable or any step of the LLM path raises.

        Args:
            query: Tool selection query
            node: Search node representing the tool
            visited_tools: List of already scored tool names (for diversity)

        Returns:
            Relevance score (0-1)
        """
        llm_client = self._get_llm_client()

        if llm_client is None:
            # Fallback to simple keyword-based scoring
            return self._fallback_score(query, node)

        try:
            # Build prompt from cached metadata (capabilities may be absent)
            tool_info = self._tool_cache.get(node.tool_id, {})
            capabilities = tool_info.get("capabilities", [])
            cap_str = ", ".join(capabilities) if capabilities else "N/A"

            base_prompt = TOOL_RELEVANCE_PROMPT.format(
                query=query.instruction,
                tool_name=node.tool_name,
                tool_description=node.tool_description or "No description",
                tool_capabilities=cap_str,
            )

            # Add diversity prompt if enabled and there are visited tools
            if self.config.use_diversity_prompt and visited_tools:
                prompt = DIVERSITY_PROMPT.format(
                    previous_tools=", ".join(visited_tools[-3:]),  # Last 3 tools
                    base_prompt=base_prompt,
                )
            else:
                prompt = base_prompt

            # Call LLM
            response = llm_client.chat(
                [{"role": "user", "content": prompt}],
                temperature=self.config.temperature,
            )

            # Parse the 0-10 rating from the response text
            score = self._parse_score(response)
            return score / 10.0  # Normalize to 0-1

        except Exception as e:
            self.logger.warning(f"LLM scoring failed for {node.tool_id}: {e}")
            return self._fallback_score(query, node)

    def _fallback_score(self, query: ToolSelectionQuery, node: SearchNode) -> float:
        """
        Fallback scoring using keyword matching when the LLM is unavailable.

        Score is the fraction of whitespace-separated query words that also
        appear in the tool's name/description text (case-insensitive).

        Args:
            query: Tool selection query
            node: Search node representing the tool

        Returns:
            Relevance score (0-1)
        """
        query_lower = query.instruction.lower()
        tool_text = f"{node.tool_name} {node.tool_description}".lower()

        # Simple keyword overlap
        query_words = set(query_lower.split())
        tool_words = set(tool_text.split())

        if not query_words or not tool_words:
            return 0.0

        overlap = len(query_words & tool_words)
        score = overlap / len(query_words)

        return min(score, 1.0)

    def _parse_score(self, response: str) -> float:
        """
        Parse a numeric score from an LLM response.

        Takes the first number found in the text (clamped to 0-10); if none
        is present, infers a coarse rating from sentiment keywords, defaulting
        to 1.0.

        Args:
            response: LLM response string

        Returns:
            Parsed score (0-10)
        """
        import re

        # Try to extract first number from response
        numbers = re.findall(r"(\d+(?:\.\d+)?)", response.strip())
        if numbers:
            score = float(numbers[0])
            return min(max(score, 0.0), 10.0)  # Clamp to 0-10

        # If no number found, try to infer from keywords
        response_lower = response.lower()
        if any(word in response_lower for word in ["perfect", "excellent", "exactly"]):
            return 9.0
        elif any(word in response_lower for word in ["highly", "very relevant", "good"]):
            return 7.0
        elif any(word in response_lower for word in ["moderate", "somewhat", "partial"]):
            return 5.0
        elif any(word in response_lower for word in ["slight", "minimal", "limited"]):
            return 3.0
        else:
            return 1.0

    def get_stats(self) -> dict:
        """Get selector statistics (base stats plus DFSDT-specific fields)."""
        stats = super().get_stats()
        stats.update(
            {
                "llm_initialized": self._llm_initialized,
                "tool_cache_size": len(self._tool_cache),
                "has_keyword_prefilter": self._keyword_selector is not None,
            }
        )
        return stats
@@ -0,0 +1,281 @@
1
+ """
2
+ Embedding-based tool selector.
3
+
4
+ Uses embedding models and vector similarity search for tool selection.
5
+ """
6
+
7
+ import logging
8
+ from typing import Optional
9
+
10
+ import numpy as np
11
+
12
+ from .base import BaseToolSelector, SelectorResources
13
+ from .retriever.vector_index import VectorIndex
14
+ from .schemas import (
15
+ EmbeddingSelectorConfig,
16
+ SelectorConfig,
17
+ ToolPrediction,
18
+ ToolSelectionQuery,
19
+ )
20
+
21
+ logger = logging.getLogger(__name__)
22
+
23
+
24
class EmbeddingSelector(BaseToolSelector):
    """
    Embedding-based tool selector using vector similarity.

    Uses embedding service to encode queries and tools, then performs
    similarity search using VectorIndex.

    Optimized for large-scale tool selection (1000+ tools).
    """

    def __init__(self, config: EmbeddingSelectorConfig, resources: SelectorResources):
        """
        Initialize embedding selector.

        All tools are embedded eagerly here (see _preprocess_tools), so
        construction cost scales with the tool-set size.

        Args:
            config: Embedding selector configuration
            resources: Shared resources including embedding_client

        Raises:
            ValueError: If embedding_client is not provided in resources
        """
        super().__init__(config, resources)
        self.config: EmbeddingSelectorConfig = config

        # Validate embedding client up front: this selector cannot work
        # without one, so fail fast instead of at first query.
        if not resources.embedding_client:
            raise ValueError(
                "EmbeddingSelector requires embedding_client in SelectorResources. "
                "Please provide an EmbeddingService instance."
            )

        self.embedding_client = resources.embedding_client

        # Vector index state; populated by _preprocess_tools below.
        self._index: Optional[VectorIndex] = None
        self._tool_texts: dict[str, str] = {}  # tool_id -> searchable text
        self._embedding_dimension: Optional[int] = None

        # Preprocess tools (embed + build index)
        self._preprocess_tools()

    @classmethod
    def from_config(
        cls, config: SelectorConfig, resources: SelectorResources
    ) -> "EmbeddingSelector":
        """Create embedding selector from config.

        Unlike some sibling selectors, a generic SelectorConfig is rejected
        rather than converted, since embedding-specific fields are required.
        """
        if not isinstance(config, EmbeddingSelectorConfig):
            raise TypeError(f"Expected EmbeddingSelectorConfig, got {type(config).__name__}")
        return cls(config, resources)  # type: ignore[arg-type]

    def _preprocess_tools(self) -> None:
        """Preprocess all tools and build the vector index.

        Embeds every tool's searchable text in one batch, infers the
        embedding dimension from the result, and loads everything into a
        VectorIndex. Re-raises on failure: an embedding selector without an
        index is unusable.
        """
        try:
            tools_loader = self.resources.tools_loader

            # Collect all tool texts
            tool_ids = []
            tool_texts = []

            for tool in tools_loader.iter_all():
                text = self._build_tool_text(tool)
                self._tool_texts[tool.tool_id] = text
                tool_ids.append(tool.tool_id)
                tool_texts.append(text)

            if not tool_texts:
                self.logger.warning("No tools found to preprocess")
                return

            self.logger.info(f"Embedding {len(tool_texts)} tools...")

            # Embed all tools in batch
            embeddings = self._embed_texts(tool_texts)

            # Infer dimension from embeddings rather than config, so the
            # index always matches what the embedding model actually emits.
            self._embedding_dimension = embeddings.shape[1]

            # Build vector index
            self._index = VectorIndex(
                dimension=self._embedding_dimension, metric=self.config.similarity_metric
            )

            # Add all vectors to index
            self._index.add_batch(
                vector_ids=tool_ids,
                vectors=embeddings,
                metadata=[{"text": text} for text in tool_texts],
            )

            self.logger.info(
                f"Built vector index with {len(tool_ids)} tools "
                f"(dimension={self._embedding_dimension}, metric={self.config.similarity_metric})"
            )

        except Exception as e:
            self.logger.error(f"Error preprocessing tools: {e}")
            raise

    def _build_tool_text(self, tool) -> str:
        """
        Build searchable text from tool metadata.

        Concatenates name, description, capabilities, category, and any
        parameter descriptions into one space-joined string; all optional
        fields are skipped when absent or empty.

        Args:
            tool: Tool object with metadata

        Returns:
            Concatenated text representation
        """
        parts = [tool.name]

        if hasattr(tool, "description") and tool.description:
            parts.append(tool.description)

        if hasattr(tool, "capabilities") and tool.capabilities:
            if isinstance(tool.capabilities, list):
                parts.extend(tool.capabilities)
            else:
                parts.append(str(tool.capabilities))

        if hasattr(tool, "category") and tool.category:
            parts.append(tool.category)

        # Include parameter descriptions if available
        if hasattr(tool, "parameters") and tool.parameters:
            if isinstance(tool.parameters, dict):
                for param_name, param_info in tool.parameters.items():
                    if isinstance(param_info, dict) and "description" in param_info:
                        parts.append(f"{param_name}: {param_info['description']}")

        return " ".join(parts)

    def _embed_texts(self, texts: list[str]) -> np.ndarray:
        """
        Embed texts using the embedding client.

        Args:
            texts: List of texts to embed

        Returns:
            Array of embeddings (shape: N x D)

        Raises:
            Exception: Propagates any embedding-client failure after logging.
        """
        try:
            # Use embedding client to embed texts.
            # The client should handle batching internally; "default" is
            # mapped to model=None so the client picks its own default model.
            embeddings = self.embedding_client.embed(
                texts=texts,
                model=self.config.embedding_model
                if self.config.embedding_model != "default"
                else None,
                batch_size=self.config.batch_size,
            )

            # Convert to numpy array if not already
            embeddings = np.asarray(embeddings)

            # Ensure 2D shape (a single embedding comes back 1-D)
            if embeddings.ndim == 1:
                embeddings = embeddings.reshape(1, -1)

            return embeddings

        except Exception as e:
            self.logger.error(f"Error embedding texts: {e}")
            raise

    def _select_impl(self, query: ToolSelectionQuery, top_k: int) -> list[ToolPrediction]:
        """
        Select tools using embedding similarity.

        Args:
            query: Tool selection query
            top_k: Number of tools to select

        Returns:
            List of tool predictions sorted by similarity
        """
        if self._index is None:
            # No index means preprocessing found zero tools; nothing to rank.
            self.logger.error("Vector index not initialized")
            return []

        try:
            # Embed query
            query_embedding = self._embed_texts([query.instruction])
            query_vector = query_embedding[0]  # Get single vector

            # Filter candidates if specified
            candidate_ids = None
            if query.candidate_tools:
                candidate_ids = set(query.candidate_tools)
                # Only search among candidates that exist in the index.
                # NOTE(review): reads the index's private _ids attribute;
                # confirm VectorIndex offers no public id accessor.
                candidate_ids = candidate_ids & set(self._index._ids)
                if not candidate_ids:
                    self.logger.warning(f"No valid candidates for query {query.sample_id}")
                    return []

            # Search vector index
            # Returns list of (vector_id, score) tuples
            results = self._index.search(
                query_vector=query_vector,
                top_k=top_k,
                filter_ids=list(candidate_ids) if candidate_ids else None,
            )

            # Convert to ToolPrediction format
            predictions = []
            for tool_id, score in results:
                # Normalize score to [0, 1] range
                # For euclidean, VectorIndex returns negative distances
                # Convert to similarity: higher is better
                if self.config.similarity_metric == "euclidean":
                    # Negative distance -> convert to positive similarity
                    score = max(0.0, 1.0 / (1.0 + abs(score)))
                else:
                    # Cosine and dot are already similarities
                    score = max(0.0, min(1.0, float(score)))

                predictions.append(
                    ToolPrediction(
                        tool_id=tool_id,
                        score=score,
                        metadata={
                            "similarity_metric": self.config.similarity_metric,
                            "embedding_model": self.config.embedding_model,
                        },
                    )
                )

            # Filter by minimum similarity threshold if specified on config
            if hasattr(self.config, "similarity_threshold"):
                threshold = self.config.similarity_threshold
                predictions = [p for p in predictions if p.score >= threshold]

            return predictions

        except Exception as e:
            self.logger.error(f"Error in embedding selection for {query.sample_id}: {e}")
            raise

    def get_embedding_dimension(self) -> Optional[int]:
        """Get embedding dimension (None until tools have been embedded)."""
        return self._embedding_dimension

    def get_index_size(self) -> int:
        """Get number of tools in the index (0 when no index was built)."""
        return len(self._index._ids) if self._index else 0

    def get_stats(self) -> dict:
        """Get selector statistics (base stats plus embedding-specific fields)."""
        stats = super().get_stats()
        stats.update(
            {
                "embedding_dimension": self._embedding_dimension,
                "index_size": self.get_index_size(),
                "similarity_metric": self.config.similarity_metric,
                "embedding_model": self.config.embedding_model,
            }
        )
        return stats