isage-tooluse 0.1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- isage_tooluse-0.1.0.0.dist-info/METADATA +208 -0
- isage_tooluse-0.1.0.0.dist-info/RECORD +14 -0
- isage_tooluse-0.1.0.0.dist-info/WHEEL +5 -0
- isage_tooluse-0.1.0.0.dist-info/licenses/LICENSE +21 -0
- isage_tooluse-0.1.0.0.dist-info/top_level.txt +1 -0
- sage_libs/sage_tooluse/__init__.py +75 -0
- sage_libs/sage_tooluse/base.py +203 -0
- sage_libs/sage_tooluse/dfsdt_selector.py +402 -0
- sage_libs/sage_tooluse/embedding_selector.py +281 -0
- sage_libs/sage_tooluse/gorilla_selector.py +495 -0
- sage_libs/sage_tooluse/hybrid_selector.py +202 -0
- sage_libs/sage_tooluse/keyword_selector.py +270 -0
- sage_libs/sage_tooluse/registry.py +185 -0
- sage_libs/sage_tooluse/schemas.py +196 -0
|
@@ -0,0 +1,270 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Keyword-based tool selector.
|
|
3
|
+
|
|
4
|
+
Implements TF-IDF and token overlap strategies for tool selection.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import logging
|
|
8
|
+
import re
|
|
9
|
+
from collections import Counter
|
|
10
|
+
|
|
11
|
+
import numpy as np
|
|
12
|
+
|
|
13
|
+
from .base import BaseToolSelector, SelectorResources
|
|
14
|
+
from .schemas import KeywordSelectorConfig, ToolPrediction, ToolSelectionQuery
|
|
15
|
+
|
|
16
|
+
logger = logging.getLogger(__name__)


# Common English stopwords. Tokens found in this set are dropped during
# tokenization when stopword removal is enabled in the selector config.
STOPWORDS = set(
    (
        "a", "an", "and", "are", "as", "at", "be",
        "by", "for", "from", "has", "he", "in", "is",
        "it", "its", "of", "on", "that", "the", "to",
        "was", "will", "with", "this", "but", "they", "have",
    )
)
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
class KeywordSelector(BaseToolSelector):
    """
    Keyword-based tool selector using TF-IDF, token overlap, or simplified BM25.

    Tool metadata is tokenized once at construction time. Each query is then
    scored against the precomputed token sets in a single pass over the
    candidate tools, and the ``top_k`` best-scoring tools are returned.
    """

    # Compiled once here instead of being re-resolved on every ``_tokenize`` call.
    _TOKEN_PATTERN = re.compile(r"\b[a-z0-9_]+\b", re.IGNORECASE)

    def __init__(self, config: KeywordSelectorConfig, resources: SelectorResources):
        """
        Initialize keyword selector.

        Args:
            config: Keyword selector configuration
            resources: Shared resources; ``resources.tools_loader.iter_all()``
                is consumed to build the tool index.
        """
        super().__init__(config, resources)
        self.config: KeywordSelectorConfig = config

        # Precomputed tool representations, keyed by tool_id.
        self._tool_texts: dict[str, str] = {}
        self._tool_tokens: dict[str, set[str]] = {}
        self._idf_scores: dict[str, float] = {}
        # Mean token-set size over all tools, cached for BM25 so that it is
        # not recomputed for every candidate. 0.0 means "not computed yet".
        self._avg_doc_len: float = 0.0

        self._preprocess_tools()

    @classmethod
    def from_config(
        cls, config: KeywordSelectorConfig, resources: SelectorResources
    ) -> "KeywordSelector":
        """Create keyword selector from config."""
        return cls(config, resources)

    def _preprocess_tools(self) -> None:
        """
        Tokenize every tool and compute the corpus statistics.

        Fills ``_tool_texts`` and ``_tool_tokens``, caches the average
        document length, and computes IDF scores when the configured method
        is ``tfidf`` or ``bm25``.

        Raises:
            Exception: Any error raised while loading or processing tools is
                logged and re-raised unchanged.
        """
        try:
            tools_loader = self.resources.tools_loader

            for tool in tools_loader.iter_all():
                text = self._build_tool_text(tool)
                self._tool_texts[tool.tool_id] = text
                self._tool_tokens[tool.tool_id] = self._tokenize(text)

            if self._tool_tokens:
                self._avg_doc_len = float(
                    np.mean([len(tokens) for tokens in self._tool_tokens.values()])
                )

            # IDF is needed for both TF-IDF and BM25; overlap scoring skips it.
            if self.config.method in ("tfidf", "bm25"):
                self._compute_idf()

            self.logger.info(f"Preprocessed {len(self._tool_texts)} tools")

        except Exception as e:
            self.logger.error(f"Error preprocessing tools: {e}")
            raise

    def _build_tool_text(self, tool) -> str:
        """
        Build searchable text from tool metadata.

        Concatenates the tool's name with its description, capabilities and
        category when those attributes exist and are truthy.

        Args:
            tool: Object exposing ``name`` and optionally ``description``,
                ``capabilities`` (list or scalar) and ``category``.

        Returns:
            The collected parts joined with single spaces.
        """
        parts = [tool.name]

        description = getattr(tool, "description", None)
        if description:
            parts.append(description)

        capabilities = getattr(tool, "capabilities", None)
        if capabilities:
            if isinstance(capabilities, list):
                parts.extend(capabilities)
            else:
                parts.append(str(capabilities))

        category = getattr(tool, "category", None)
        if category:
            parts.append(category)

        # Coerce to str so a non-string metadata value (e.g. an enum category
        # or a numeric capability) cannot make ``join`` raise TypeError.
        return " ".join(str(part) for part in parts if part is not None)

    def _tokenize(self, text: str) -> set[str]:
        """
        Tokenize text into a set of tokens.

        Applies lowercasing, stopword removal and n-gram generation according
        to the selector config. N-grams are joined with ``_`` and are added
        alongside the unigrams.

        Args:
            text: Raw text to tokenize.

        Returns:
            Set of unique tokens (unigrams plus any generated n-grams).
        """
        if self.config.lowercase:
            text = text.lower()

        # Runs of letters, digits and underscores; everything else separates.
        tokens = self._TOKEN_PATTERN.findall(text)

        if self.config.remove_stopwords:
            tokens = [t for t in tokens if t.lower() not in STOPWORDS]

        max_n = self.config.ngram_range[1]
        if max_n > 1:
            # Unigrams are already in ``tokens``, so start at 2. This also
            # keeps a lower bound <= 0 from producing empty-string tokens.
            min_n = max(self.config.ngram_range[0], 2)
            ngrams = [
                "_".join(tokens[i : i + n])
                for n in range(min_n, max_n + 1)
                for i in range(len(tokens) - n + 1)
            ]
            tokens.extend(ngrams)

        return set(tokens)

    def _compute_idf(self) -> None:
        """
        Compute IDF scores for all tokens.

        Uses the unsmoothed form ``log(N / df)``, where ``N`` is the number of
        tools and ``df`` is the number of tools containing the token. A token
        present in every tool therefore gets an IDF of 0.
        """
        total_docs = len(self._tool_tokens)

        # Token sets hold unique tokens, so each tool counts once per token.
        df = Counter()
        for tokens in self._tool_tokens.values():
            df.update(tokens)

        for token, freq in df.items():
            self._idf_scores[token] = np.log(total_docs / freq)

    def _select_impl(self, query: ToolSelectionQuery, top_k: int) -> list[ToolPrediction]:
        """
        Select tools using keyword matching.

        Args:
            query: Tool selection query
            top_k: Number of tools to select

        Returns:
            List of tool predictions ordered by descending score. Tools with
            equal scores are ordered by tool id so results are reproducible.

        Raises:
            ValueError: If ``config.method`` is not ``tfidf``, ``overlap`` or
                ``bm25`` and at least one candidate has to be scored.
        """
        query_tokens = self._tokenize(query.instruction)

        if not query_tokens:
            self.logger.warning(f"No tokens in query {query.sample_id}")
            return []

        # An empty candidate list means "consider every known tool".
        candidate_ids = (
            set(query.candidate_tools) if query.candidate_tools else set(self._tool_texts.keys())
        )

        # Resolve the scoring function once rather than per candidate.
        scorer = {
            "tfidf": self._tfidf_score,
            "overlap": self._overlap_score,
            "bm25": self._bm25_score,
        }.get(self.config.method)

        scores = []
        for tool_id in candidate_ids:
            if tool_id not in self._tool_tokens:
                continue

            if scorer is None:
                raise ValueError(f"Unknown method: {self.config.method}")

            scores.append((tool_id, scorer(query_tokens, tool_id)))

        # Highest score first. Set iteration order depends on string hashing,
        # so ties are broken on tool_id to keep the output deterministic.
        scores.sort(key=lambda item: (-item[1], item[0]))
        scores = scores[:top_k]

        return [
            ToolPrediction(
                tool_id=tool_id,
                # Clamp to the upper bound; TF-IDF and BM25 sums can exceed 1.
                score=min(score, 1.0),
                metadata={"method": self.config.method},
            )
            for tool_id, score in scores
        ]

    def _tfidf_score(self, query_tokens: set[str], tool_id: str) -> float:
        """
        Compute TF-IDF score.

        Sums the IDF of every token shared by the query and the tool, then
        divides by the number of query tokens.
        """
        common = query_tokens & self._tool_tokens[tool_id]

        if not common:
            return 0.0

        score = sum(self._idf_scores.get(token, 0.0) for token in common)

        # Normalize by query length
        return score / len(query_tokens)

    def _overlap_score(self, query_tokens: set[str], tool_id: str) -> float:
        """Compute token overlap score (Jaccard similarity)."""
        tool_tokens = self._tool_tokens[tool_id]

        if not query_tokens or not tool_tokens:
            return 0.0

        intersection = len(query_tokens & tool_tokens)
        union = len(query_tokens | tool_tokens)

        return intersection / union if union > 0 else 0.0

    def _bm25_score(self, query_tokens: set[str], tool_id: str) -> float:
        """
        Compute BM25 score (simplified).

        Term frequency is binary because tools are stored as token sets, and
        document length is the size of the tool's token set.
        """
        tool_tokens = self._tool_tokens[tool_id]
        common = query_tokens & tool_tokens

        if not common:
            return 0.0

        # BM25 parameters
        k1 = 1.5
        b = 0.75

        # Average document length is a corpus-wide constant; recomputing it
        # for every candidate made a query quadratic in the number of tools.
        avg_len = self._avg_doc_len
        if not avg_len:
            avg_len = self._avg_doc_len = float(
                np.mean([len(tokens) for tokens in self._tool_tokens.values()])
            )
        doc_len = len(tool_tokens)

        tf = 1  # Binary TF

        # With binary TF the saturation factor is identical for every token
        # of this tool, so it is computed once outside the loop.
        numerator = tf * (k1 + 1)
        denominator = tf + k1 * (1 - b + b * doc_len / avg_len)
        weight = numerator / denominator

        score = 0.0
        for token in common:
            score += self._idf_scores.get(token, 0.0) * weight

        return score
|
|
@@ -0,0 +1,185 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Registry for tool selector strategies.
|
|
3
|
+
|
|
4
|
+
Provides registration, lookup, and factory creation of selectors.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import logging
|
|
8
|
+
from typing import Any, Optional
|
|
9
|
+
|
|
10
|
+
from .base import BaseToolSelector, SelectorResources
|
|
11
|
+
from .schemas import SelectorConfig, create_selector_config
|
|
12
|
+
|
|
13
|
+
# Module-level logger; used below for registration and cache-clearing messages.
logger = logging.getLogger(__name__)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class SelectorRegistry:
    """
    Registry for tool selector strategies.

    Maps strategy names to selector classes, builds selector instances from
    them, and optionally caches one instance per name.
    """

    _instance: Optional["SelectorRegistry"] = None
    _selectors: dict[str, type[BaseToolSelector]] = {}

    def __init__(self):
        """Initialize registry with no registered classes and an empty cache."""
        # Shadows the class-level dict so each registry has its own mapping.
        self._selectors = {}
        self._instances: dict[str, BaseToolSelector] = {}

    @classmethod
    def get_instance(cls) -> "SelectorRegistry":
        """Get singleton registry instance, creating it on first use."""
        if cls._instance is None:
            cls._instance = cls()
        return cls._instance

    def register(self, name: str, selector_class: type[BaseToolSelector]) -> None:
        """
        Register a selector class.

        Registering a name that already exists replaces the previous class
        and logs a warning. If the class actually changes, any instance
        cached under that name is discarded so ``get`` cannot keep returning
        a selector built from the replaced class.

        Args:
            name: Selector strategy name
            selector_class: Selector class to register
        """
        if name in self._selectors:
            logger.warning("Overwriting existing selector: %s", name)
            if self._selectors[name] is not selector_class:
                self._instances.pop(name, None)

        self._selectors[name] = selector_class
        logger.info("Registered selector: %s", name)

    def get_class(self, name: str) -> Optional[type[BaseToolSelector]]:
        """
        Get selector class by name.

        Args:
            name: Selector strategy name

        Returns:
            Selector class or None if not found
        """
        return self._selectors.get(name)

    def get(
        self,
        name: str,
        config: Optional[SelectorConfig] = None,
        resources: Optional[SelectorResources] = None,
        cache: bool = True,
    ) -> BaseToolSelector:
        """
        Get or create selector instance.

        The cache is keyed by ``name`` only: when ``cache`` is true and an
        instance already exists for ``name``, it is returned as-is and the
        ``config`` and ``resources`` arguments are ignored.

        Args:
            name: Selector strategy name
            config: Optional selector configuration; when omitted a default
                one is built with ``create_selector_config({"name": name})``
            resources: Optional resources (required for new instances)
            cache: Whether to cache and reuse instances

        Returns:
            Selector instance

        Raises:
            ValueError: If selector not registered or resources missing
        """
        if cache and name in self._instances:
            return self._instances[name]

        selector_class = self.get_class(name)
        if selector_class is None:
            raise ValueError(f"Unknown selector: {name}. Available: {list(self._selectors.keys())}")

        if config is None:
            config = create_selector_config({"name": name})

        if resources is None:
            raise ValueError(f"Resources required to create selector: {name}")

        instance = selector_class.from_config(config, resources)

        if cache:
            self._instances[name] = instance

        return instance

    def create_from_config(
        self, config_dict: dict[str, Any], resources: SelectorResources
    ) -> BaseToolSelector:
        """
        Create selector from configuration dictionary.

        Always builds a fresh instance; the instance cache is bypassed.

        Args:
            config_dict: Configuration dictionary
            resources: Shared resources

        Returns:
            Initialized selector instance
        """
        config = create_selector_config(config_dict)
        return self.get(config.name, config, resources, cache=False)

    def list_selectors(self) -> list[str]:
        """List all registered selector names."""
        return list(self._selectors)

    def clear_cache(self) -> None:
        """Clear cached selector instances; registered classes are kept."""
        self._instances.clear()
        logger.info("Cleared selector instance cache")
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
# Global registry instance: the singleton that the module-level helper
# functions below delegate to. It is created when this module is imported.
_registry = SelectorRegistry.get_instance()
|
|
140
|
+
|
|
141
|
+
|
|
142
|
+
def register_selector(name: str, selector_class: type[BaseToolSelector]) -> None:
    """
    Add a selector class to the global registry.

    Thin wrapper around ``SelectorRegistry.register`` on the module-level
    registry instance.

    Args:
        name: Strategy name to register the class under
        selector_class: Selector class to associate with ``name``
    """
    _registry.register(name=name, selector_class=selector_class)
|
|
151
|
+
|
|
152
|
+
|
|
153
|
+
def get_selector(
    name: str,
    config: Optional[SelectorConfig] = None,
    resources: Optional[SelectorResources] = None,
) -> BaseToolSelector:
    """
    Fetch a selector instance from the global registry.

    Delegates to ``SelectorRegistry.get`` with its default caching
    behaviour, so a previously created instance for ``name`` is reused.

    Args:
        name: Strategy name of the selector
        config: Selector configuration, if any
        resources: Resources used when a new instance has to be built

    Returns:
        The selector instance for ``name``
    """
    selector = _registry.get(name, config=config, resources=resources)
    return selector
|
|
170
|
+
|
|
171
|
+
|
|
172
|
+
def create_selector_from_config(
    config_dict: dict[str, Any], resources: SelectorResources
) -> BaseToolSelector:
    """
    Build a selector from a plain configuration dictionary.

    Delegates to ``SelectorRegistry.create_from_config`` on the global
    registry.

    Args:
        config_dict: Dictionary describing the selector configuration
        resources: Resources shared with the new selector

    Returns:
        The newly initialized selector
    """
    selector = _registry.create_from_config(config_dict=config_dict, resources=resources)
    return selector
|
|
@@ -0,0 +1,196 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Data schemas for tool selection.
|
|
3
|
+
|
|
4
|
+
Defines Pydantic models for queries, predictions, and configurations.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from typing import Any, Optional
|
|
8
|
+
|
|
9
|
+
from pydantic import BaseModel, Field
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class ToolSelectionQuery(BaseModel):
    """
    A single tool-selection request.

    Carries the instruction text together with the IDs of the tools that may
    be chosen for it. ``sample_id``, ``instruction`` and ``candidate_tools``
    are required; ``context`` and ``metadata`` default to empty dicts.
    """

    sample_id: str = Field(
        ...,
        description="Unique identifier for the query",
    )
    instruction: str = Field(
        ...,
        description="User instruction or task description",
    )
    context: dict[str, Any] = Field(
        default_factory=dict,
        description="Additional context",
    )
    candidate_tools: list[str] = Field(
        ...,
        description="List of candidate tool IDs",
    )
    metadata: dict[str, Any] = Field(
        default_factory=dict,
        description="Optional metadata",
    )

    class Config:
        # Fields beyond the ones declared above are accepted.
        extra = "allow"
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class ToolPrediction(BaseModel):
    """
    Scored result for one tool.

    ``score`` is constrained to the closed interval [0, 1]. The model is
    declared frozen.
    """

    tool_id: str = Field(
        ...,
        description="Tool identifier",
    )
    score: float = Field(
        ...,
        ge=0.0,
        le=1.0,
        description="Relevance score (0-1)",
    )
    explanation: Optional[str] = Field(
        None,
        description="Optional explanation",
    )
    metadata: dict[str, Any] = Field(
        default_factory=dict,
        description="Additional metadata",
    )

    class Config:
        # Make immutable for caching
        frozen = True
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class SelectorConfig(BaseModel):
    """
    Settings shared by every tool selector.

    Strategy-specific config classes in this module derive from this base.
    Only ``name`` is required here; every other field has a default.
    """

    name: str = Field(
        ...,
        description="Selector strategy name",
    )
    top_k: int = Field(
        5,
        ge=1,
        description="Number of tools to select",
    )
    min_score: float = Field(
        0.0,
        ge=0.0,
        le=1.0,
        description="Minimum score threshold",
    )
    cache_enabled: bool = Field(
        True,
        description="Enable result caching",
    )
    params: dict[str, Any] = Field(
        default_factory=dict,
        description="Strategy-specific parameters",
    )

    class Config:
        # Fields beyond the ones declared above are accepted.
        extra = "allow"
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
class KeywordSelectorConfig(SelectorConfig):
    """
    Settings for the keyword-based selector.

    Controls the scoring method and how text is tokenized (lowercasing,
    stopword removal, n-gram range).
    """

    name: str = "keyword"
    method: str = Field(
        "tfidf",
        description="Keyword matching method: tfidf, overlap, bm25",
    )
    lowercase: bool = Field(
        True,
        description="Convert to lowercase",
    )
    remove_stopwords: bool = Field(
        True,
        description="Remove stopwords",
    )
    # Expected to hold (min_n, max_n); the element types are not validated.
    ngram_range: tuple = Field(
        (1, 2),
        description="N-gram range for features",
    )
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
class EmbeddingSelectorConfig(SelectorConfig):
    """
    Settings for the embedding-based selector.

    Names the embedding model and similarity metric, and configures vector
    caching and the embedding batch size.
    """

    name: str = "embedding"
    embedding_model: str = Field(
        "default",
        description="Embedding model identifier",
    )
    similarity_metric: str = Field(
        "cosine",
        description="Similarity metric: cosine, dot, euclidean",
    )
    use_cache: bool = Field(
        True,
        description="Cache embedding vectors",
    )
    batch_size: int = Field(
        32,
        ge=1,
        description="Batch size for embedding",
    )
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
class TwoStageSelectorConfig(SelectorConfig):
    """
    Settings for the two-stage selector.

    Names the selector used for coarse retrieval and the one used for
    reranking, along with the coarse candidate count and the fusion weight.
    """

    name: str = "two_stage"
    coarse_k: int = Field(
        20,
        ge=1,
        description="Number of candidates from coarse retrieval",
    )
    coarse_selector: str = Field(
        "keyword",
        description="Coarse retrieval selector",
    )
    rerank_selector: str = Field(
        "embedding",
        description="Reranking selector",
    )
    fusion_weight: float = Field(
        0.5,
        ge=0.0,
        le=1.0,
        description="Weight for score fusion",
    )
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
class AdaptiveSelectorConfig(SelectorConfig):
    """
    Settings for the adaptive selector.

    Lists the underlying strategies and configures how one is chosen
    (selection method, exploration rate, update interval).
    """

    name: str = "adaptive"
    strategies: list[str] = Field(
        # A factory is used so each config gets its own list object.
        default_factory=lambda: ["keyword", "embedding"],
        description="List of strategies",
    )
    selection_method: str = Field(
        "bandit",
        description="Selection method: bandit, ensemble, threshold",
    )
    exploration_rate: float = Field(
        0.1,
        ge=0.0,
        le=1.0,
        description="Exploration rate for bandit",
    )
    update_interval: int = Field(
        100,
        ge=1,
        description="Update interval for adaptation",
    )
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
class DFSDTSelectorConfig(SelectorConfig):
    """
    Settings for the DFSDT (Depth-First Search-based Decision Tree) selector.

    Reference: ToolLLM paper (Qin et al., 2023),
    "ToolLLM: Facilitating Large Language Models to Master 16000+ Real-world APIs"

    Covers the search bounds (depth, beam width), the LLM used for scoring,
    the pruning threshold, and the optional keyword pre-filter.
    """

    name: str = "dfsdt"
    max_depth: int = Field(
        3,
        ge=1,
        le=10,
        description="Maximum search depth",
    )
    beam_width: int = Field(
        5,
        ge=1,
        le=20,
        description="Number of candidates per level",
    )
    llm_model: str = Field(
        "auto",
        description="LLM model for scoring (auto uses UnifiedInferenceClient)",
    )
    temperature: float = Field(
        0.1,
        ge=0.0,
        le=2.0,
        description="LLM sampling temperature",
    )
    use_diversity_prompt: bool = Field(
        True,
        description="Use diversity prompting for exploration",
    )
    score_threshold: float = Field(
        0.3,
        ge=0.0,
        le=1.0,
        description="Minimum score threshold for pruning",
    )
    use_keyword_prefilter: bool = Field(
        True,
        description="Use keyword matching to pre-filter candidates",
    )
    prefilter_k: int = Field(
        20,
        ge=5,
        le=100,
        description="Number of candidates after pre-filtering",
    )
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
class GorillaSelectorConfig(SelectorConfig):
    """
    Settings for the Gorilla-style retrieval-augmented selector.

    Reference: Gorilla paper (Patil et al., 2023),
    "Gorilla: Large Language Model Connected with Massive APIs"

    The approach has two stages: embedding retrieval followed by LLM
    selection. The fields below size each stage and pick the models.
    """

    name: str = "gorilla"
    top_k_retrieve: int = Field(
        20,
        ge=1,
        description="Number of tools to retrieve in first stage",
    )
    top_k_select: int = Field(
        5,
        ge=1,
        description="Number of tools to select in final output",
    )
    embedding_model: str = Field(
        "default",
        description="Embedding model for retrieval",
    )
    llm_model: str = Field(
        "auto",
        description="LLM model for selection (auto uses UnifiedInferenceClient)",
    )
    similarity_metric: str = Field(
        "cosine",
        description="Similarity metric: cosine, dot, euclidean",
    )
    temperature: float = Field(
        0.1,
        ge=0.0,
        le=2.0,
        description="LLM temperature for selection",
    )
    use_detailed_docs: bool = Field(
        True,
        description="Include detailed parameter docs in context",
    )
    max_context_tools: int = Field(
        15,
        ge=1,
        description="Max tools to include in LLM context",
    )
|
|
164
|
+
|
|
165
|
+
|
|
166
|
+
# Config type registry: maps a selector strategy name to the config class
# that create_selector_config instantiates for it.
CONFIG_TYPES = dict(
    keyword=KeywordSelectorConfig,
    embedding=EmbeddingSelectorConfig,
    two_stage=TwoStageSelectorConfig,
    adaptive=AdaptiveSelectorConfig,
    dfsdt=DFSDTSelectorConfig,
    gorilla=GorillaSelectorConfig,
)
|
|
175
|
+
|
|
176
|
+
|
|
177
|
+
def create_selector_config(config_dict: dict[str, Any]) -> SelectorConfig:
    """
    Build the typed config object that matches a configuration dictionary.

    The ``"name"`` entry picks the config class from ``CONFIG_TYPES``; when
    it is absent, ``"keyword"`` is assumed. The whole dictionary is then
    passed to that class as keyword arguments.

    Args:
        config_dict: Configuration dictionary

    Returns:
        Instance of the SelectorConfig subclass registered for the name

    Raises:
        ValueError: If the selector name has no entry in ``CONFIG_TYPES``
    """
    selector_name = config_dict.get("name", "keyword")
    config_class = CONFIG_TYPES.get(selector_name)

    if config_class is None:
        raise ValueError(f"Unknown selector type: {selector_name}")

    return config_class(**config_dict)
|