isage-tooluse 0.1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,270 @@
1
+ """
2
+ Keyword-based tool selector.
3
+
4
+ Implements TF-IDF and token overlap strategies for tool selection.
5
+ """
6
+
7
+ import logging
8
+ import re
9
+ from collections import Counter
10
+
11
+ import numpy as np
12
+
13
+ from .base import BaseToolSelector, SelectorResources
14
+ from .schemas import KeywordSelectorConfig, ToolPrediction, ToolSelectionQuery
15
+
16
+ logger = logging.getLogger(__name__)
17
+
18
+
19
# Common English stopwords excluded from keyword matching.
STOPWORDS = set(
    "a an and are as at be by for from has he in is it its of on that "
    "the to was will with this but they have".split()
)
50
+
51
+
52
class KeywordSelector(BaseToolSelector):
    """
    Keyword-based tool selector using TF-IDF, token overlap, or BM25.

    Fast baseline selector: tool texts are tokenized once at construction
    time, and each query is scored against its candidates with O(N) set
    operations.
    """

    def __init__(self, config: KeywordSelectorConfig, resources: SelectorResources):
        """
        Initialize keyword selector.

        Args:
            config: Keyword selector configuration
            resources: Shared resources
        """
        super().__init__(config, resources)
        self.config: KeywordSelectorConfig = config

        # Precomputed tool text representations.
        self._tool_texts: dict[str, str] = {}
        self._tool_tokens: dict[str, set[str]] = {}
        self._idf_scores: dict[str, float] = {}
        # Mean token-set size across all tools, used by BM25 length
        # normalization. Precomputed once here instead of recomputing the
        # corpus mean on every scoring call (which made scoring O(N^2)).
        self._avg_doc_len: float = 0.0

        self._preprocess_tools()

    @classmethod
    def from_config(
        cls, config: KeywordSelectorConfig, resources: SelectorResources
    ) -> "KeywordSelector":
        """Create keyword selector from config."""
        return cls(config, resources)

    def _preprocess_tools(self) -> None:
        """Preprocess all tools: build texts/token sets, IDF, corpus stats.

        Raises:
            Exception: re-raises any error from the tools loader after logging.
        """
        try:
            tools_loader = self.resources.tools_loader

            # Build searchable text and token set for every tool.
            for tool in tools_loader.iter_all():
                text = self._build_tool_text(tool)
                self._tool_texts[tool.tool_id] = text
                self._tool_tokens[tool.tool_id] = self._tokenize(text)

            # IDF scores are needed for both TF-IDF and BM25 scoring.
            if self.config.method in ("tfidf", "bm25"):
                self._compute_idf()

            # Average document (token-set) length for BM25, guarded against
            # an empty corpus.
            if self._tool_tokens:
                self._avg_doc_len = float(
                    np.mean([len(tokens) for tokens in self._tool_tokens.values()])
                )

            self.logger.info(f"Preprocessed {len(self._tool_texts)} tools")

        except Exception as e:
            self.logger.error(f"Error preprocessing tools: {e}")
            raise

    def _build_tool_text(self, tool) -> str:
        """Build searchable text from tool metadata (name, description,
        capabilities, category — whichever attributes are present)."""
        parts = [tool.name]

        if hasattr(tool, "description") and tool.description:
            parts.append(tool.description)

        if hasattr(tool, "capabilities") and tool.capabilities:
            if isinstance(tool.capabilities, list):
                parts.extend(tool.capabilities)
            else:
                parts.append(str(tool.capabilities))

        if hasattr(tool, "category") and tool.category:
            parts.append(tool.category)

        return " ".join(parts)

    def _tokenize(self, text: str) -> set[str]:
        """Tokenize text into a set of tokens (optionally with n-grams).

        Returns a set, so term frequency within a document is discarded.
        """
        if self.config.lowercase:
            text = text.lower()

        # Split on non-alphanumeric; underscore is kept so snake_case
        # identifiers survive as single tokens.
        tokens = re.findall(r"\b[a-z0-9_]+\b", text, re.IGNORECASE)

        # Remove stopwords if enabled.
        if self.config.remove_stopwords:
            tokens = [t for t in tokens if t.lower() not in STOPWORDS]

        # Generate n-grams if configured. Start at n=2: 1-grams are the
        # tokens themselves and would only be duplicated in the result set.
        if self.config.ngram_range[1] > 1:
            ngrams = []
            lo = max(2, self.config.ngram_range[0])
            for n in range(lo, self.config.ngram_range[1] + 1):
                for i in range(len(tokens) - n + 1):
                    ngrams.append("_".join(tokens[i : i + n]))
            tokens.extend(ngrams)

        return set(tokens)

    def _compute_idf(self) -> None:
        """Compute IDF scores for all tokens: idf(t) = log(N / df(t))."""
        # Document frequency for each token across the tool corpus.
        df = Counter()
        total_docs = len(self._tool_tokens)

        for tokens in self._tool_tokens.values():
            df.update(tokens)

        for token, freq in df.items():
            self._idf_scores[token] = np.log(total_docs / freq)

    def _select_impl(self, query: ToolSelectionQuery, top_k: int) -> list[ToolPrediction]:
        """
        Select tools using keyword matching.

        Args:
            query: Tool selection query
            top_k: Number of tools to select

        Returns:
            List of tool predictions, highest score first.

        Raises:
            ValueError: if config.method is not tfidf/overlap/bm25.
        """
        query_tokens = self._tokenize(query.instruction)

        if not query_tokens:
            self.logger.warning(f"No tokens in query {query.sample_id}")
            return []

        # Restrict to the query's candidate pool when given, otherwise
        # score the full corpus.
        candidate_ids = (
            set(query.candidate_tools) if query.candidate_tools else set(self._tool_texts.keys())
        )

        # Score each known candidate; unknown IDs are silently skipped.
        scores = []
        for tool_id in candidate_ids:
            if tool_id not in self._tool_tokens:
                continue

            if self.config.method == "tfidf":
                score = self._tfidf_score(query_tokens, tool_id)
            elif self.config.method == "overlap":
                score = self._overlap_score(query_tokens, tool_id)
            elif self.config.method == "bm25":
                score = self._bm25_score(query_tokens, tool_id)
            else:
                raise ValueError(f"Unknown method: {self.config.method}")

            scores.append((tool_id, score))

        # Sort by score and take top-k.
        # NOTE(review): config.min_score is not applied here — confirm
        # whether thresholding happens in the base class.
        scores.sort(key=lambda x: x[1], reverse=True)
        scores = scores[:top_k]

        return [
            ToolPrediction(
                tool_id=tool_id,
                score=min(score, 1.0),  # Clamp: tfidf/bm25 sums can exceed 1
                metadata={"method": self.config.method},
            )
            for tool_id, score in scores
        ]

    def _tfidf_score(self, query_tokens: set[str], tool_id: str) -> float:
        """Compute TF-IDF score: summed IDF of matched tokens, normalized
        by query length."""
        tool_tokens = self._tool_tokens[tool_id]
        common = query_tokens & tool_tokens

        if not common:
            return 0.0

        score = sum(self._idf_scores.get(token, 0.0) for token in common)

        # Normalize by query length so long queries are not favored.
        score /= len(query_tokens)

        return score

    def _overlap_score(self, query_tokens: set[str], tool_id: str) -> float:
        """Compute token overlap score (Jaccard similarity)."""
        tool_tokens = self._tool_tokens[tool_id]

        if not query_tokens or not tool_tokens:
            return 0.0

        intersection = len(query_tokens & tool_tokens)
        union = len(query_tokens | tool_tokens)

        return intersection / union if union > 0 else 0.0

    def _bm25_score(self, query_tokens: set[str], tool_id: str) -> float:
        """Compute BM25 score (simplified: binary term frequency, since
        token sets carry no counts)."""
        tool_tokens = self._tool_tokens[tool_id]
        common = query_tokens & tool_tokens

        if not common:
            return 0.0

        # Standard BM25 parameters.
        k1 = 1.5
        b = 0.75

        # Corpus mean precomputed in _preprocess_tools; fall back to this
        # document's length so the normalization term stays well-defined.
        avg_len = self._avg_doc_len or float(len(tool_tokens))
        doc_len = len(tool_tokens)

        score = 0.0
        for token in common:
            idf = self._idf_scores.get(token, 0.0)
            tf = 1  # Binary TF

            numerator = tf * (k1 + 1)
            denominator = tf + k1 * (1 - b + b * doc_len / avg_len)

            score += idf * (numerator / denominator)

        return score
@@ -0,0 +1,185 @@
1
+ """
2
+ Registry for tool selector strategies.
3
+
4
+ Provides registration, lookup, and factory creation of selectors.
5
+ """
6
+
7
+ import logging
8
+ from typing import Any, Optional
9
+
10
+ from .base import BaseToolSelector, SelectorResources
11
+ from .schemas import SelectorConfig, create_selector_config
12
+
13
+ logger = logging.getLogger(__name__)
14
+
15
+
16
class SelectorRegistry:
    """
    Registry for tool selector strategies.

    Supports registration, lookup, and factory creation of selectors.
    Use get_instance() for the process-wide singleton.
    """

    # Lazily created singleton; see get_instance().
    _instance: Optional["SelectorRegistry"] = None

    def __init__(self):
        """Initialize an empty registry."""
        # Registered classes and cached instances are per-instance state.
        # (A mutable class-level _selectors dict previously declared here
        # was always shadowed by this instance dict and has been removed.)
        self._selectors: dict[str, type[BaseToolSelector]] = {}
        self._instances: dict[str, BaseToolSelector] = {}

    @classmethod
    def get_instance(cls) -> "SelectorRegistry":
        """Get singleton registry instance, creating it on first use."""
        if cls._instance is None:
            cls._instance = cls()
        return cls._instance

    def register(self, name: str, selector_class: type[BaseToolSelector]) -> None:
        """
        Register a selector class.

        Re-registering an existing name overwrites it (with a warning).

        Args:
            name: Selector strategy name
            selector_class: Selector class to register
        """
        if name in self._selectors:
            logger.warning(f"Overwriting existing selector: {name}")

        self._selectors[name] = selector_class
        logger.info(f"Registered selector: {name}")

    def get_class(self, name: str) -> Optional[type[BaseToolSelector]]:
        """
        Get selector class by name.

        Args:
            name: Selector strategy name

        Returns:
            Selector class or None if not found
        """
        return self._selectors.get(name)

    def get(
        self,
        name: str,
        config: Optional[SelectorConfig] = None,
        resources: Optional[SelectorResources] = None,
        cache: bool = True,
    ) -> BaseToolSelector:
        """
        Get or create selector instance.

        Args:
            name: Selector strategy name
            config: Optional selector configuration (defaults to a
                minimal config built from the name)
            resources: Optional resources (required for new instances)
            cache: Whether to cache and reuse instances

        Returns:
            Selector instance

        Raises:
            ValueError: If selector not registered or resources missing
        """
        # Cache hit short-circuits everything: the supplied config and
        # resources are ignored in that case.
        if cache and name in self._instances:
            return self._instances[name]

        selector_class = self.get_class(name)
        if selector_class is None:
            raise ValueError(f"Unknown selector: {name}. Available: {list(self._selectors.keys())}")

        # Fall back to a default config for this strategy name.
        if config is None:
            config = create_selector_config({"name": name})

        # Resources are mandatory when building a new instance.
        if resources is None:
            raise ValueError(f"Resources required to create selector: {name}")

        instance = selector_class.from_config(config, resources)

        if cache:
            self._instances[name] = instance

        return instance

    def create_from_config(
        self, config_dict: dict[str, Any], resources: SelectorResources
    ) -> BaseToolSelector:
        """
        Create selector from configuration dictionary.

        Always builds a fresh instance (bypasses the cache).

        Args:
            config_dict: Configuration dictionary
            resources: Shared resources

        Returns:
            Initialized selector instance
        """
        config = create_selector_config(config_dict)
        return self.get(config.name, config, resources, cache=False)

    def list_selectors(self) -> list[str]:
        """List all registered selector names."""
        return list(self._selectors.keys())

    def clear_cache(self) -> None:
        """Clear cached selector instances (registered classes are kept)."""
        self._instances.clear()
        logger.info("Cleared selector instance cache")
136
+
137
+
138
# Global registry instance backing the module-level convenience functions.
_registry = SelectorRegistry.get_instance()
140
+
141
+
142
def register_selector(name: str, selector_class: type[BaseToolSelector]) -> None:
    """
    Register a selector class with the global singleton registry.

    Thin module-level convenience wrapper around SelectorRegistry.register.

    Args:
        name: Selector strategy name
        selector_class: Selector class to register
    """
    _registry.register(name, selector_class)
151
+
152
+
153
def get_selector(
    name: str,
    config: Optional[SelectorConfig] = None,
    resources: Optional[SelectorResources] = None,
) -> BaseToolSelector:
    """
    Fetch (or lazily build) a selector from the global registry.

    Delegates to SelectorRegistry.get with instance caching enabled.

    Args:
        name: Selector strategy name
        config: Optional selector configuration
        resources: Optional resources

    Returns:
        Selector instance
    """
    return _registry.get(name, config, resources)
170
+
171
+
172
def create_selector_from_config(
    config_dict: dict[str, Any], resources: SelectorResources
) -> BaseToolSelector:
    """
    Build a fresh selector from a config dictionary via the global registry.

    Delegates to SelectorRegistry.create_from_config (no caching).

    Args:
        config_dict: Configuration dictionary
        resources: Shared resources

    Returns:
        Initialized selector instance
    """
    return _registry.create_from_config(config_dict, resources)
@@ -0,0 +1,196 @@
1
+ """
2
+ Data schemas for tool selection.
3
+
4
+ Defines Pydantic models for queries, predictions, and configurations.
5
+ """
6
+
7
+ from typing import Any, Optional
8
+
9
+ from pydantic import BaseModel, Field
10
+
11
+
12
class ToolSelectionQuery(BaseModel):
    """Query for tool selection.

    A single selection request: the user instruction plus the pool of
    candidate tool IDs to rank.
    """

    sample_id: str = Field(..., description="Unique identifier for the query")
    instruction: str = Field(..., description="User instruction or task description")
    context: dict[str, Any] = Field(default_factory=dict, description="Additional context")
    candidate_tools: list[str] = Field(..., description="List of candidate tool IDs")
    metadata: dict[str, Any] = Field(default_factory=dict, description="Optional metadata")

    class Config:
        # Accept extra fields so datasets carrying additional keys
        # validate without errors.
        extra = "allow"
23
+
24
+
25
class ToolPrediction(BaseModel):
    """Prediction result for a single tool.

    Scores are constrained to [0, 1]; selectors must clamp before
    constructing a prediction.
    """

    tool_id: str = Field(..., description="Tool identifier")
    score: float = Field(..., ge=0.0, le=1.0, description="Relevance score (0-1)")
    explanation: Optional[str] = Field(default=None, description="Optional explanation")
    metadata: dict[str, Any] = Field(default_factory=dict, description="Additional metadata")

    class Config:
        frozen = True  # Make immutable for caching
35
+
36
+
37
class SelectorConfig(BaseModel):
    """Base configuration for tool selectors.

    Subclasses set a fixed `name` default and add strategy-specific
    fields; see CONFIG_TYPES for the name -> class mapping.
    """

    name: str = Field(..., description="Selector strategy name")
    top_k: int = Field(default=5, ge=1, description="Number of tools to select")
    min_score: float = Field(default=0.0, ge=0.0, le=1.0, description="Minimum score threshold")
    cache_enabled: bool = Field(default=True, description="Enable result caching")
    params: dict[str, Any] = Field(default_factory=dict, description="Strategy-specific parameters")

    class Config:
        # Tolerate unknown keys so forward-compatible configs still load.
        extra = "allow"
48
+
49
+
50
class KeywordSelectorConfig(SelectorConfig):
    """Configuration for keyword-based selector."""

    name: str = "keyword"
    method: str = Field(
        default="tfidf", description="Keyword matching method: tfidf, overlap, bm25"
    )
    lowercase: bool = Field(default=True, description="Convert to lowercase")
    remove_stopwords: bool = Field(default=True, description="Remove stopwords")
    # Interpreted as (min_n, max_n) inclusive by the selector's tokenizer.
    ngram_range: tuple = Field(default=(1, 2), description="N-gram range for features")
60
+
61
+
62
class EmbeddingSelectorConfig(SelectorConfig):
    """Configuration for embedding-based selector."""

    name: str = "embedding"
    embedding_model: str = Field(default="default", description="Embedding model identifier")
    similarity_metric: str = Field(
        default="cosine", description="Similarity metric: cosine, dot, euclidean"
    )
    use_cache: bool = Field(default=True, description="Cache embedding vectors")
    batch_size: int = Field(default=32, ge=1, description="Batch size for embedding")
72
+
73
+
74
class TwoStageSelectorConfig(SelectorConfig):
    """Configuration for two-stage selector.

    A coarse retrieval stage narrows to `coarse_k` candidates which are
    then reranked; the two stages' scores are fused via `fusion_weight`.
    """

    name: str = "two_stage"
    coarse_k: int = Field(
        default=20, ge=1, description="Number of candidates from coarse retrieval"
    )
    coarse_selector: str = Field(default="keyword", description="Coarse retrieval selector")
    rerank_selector: str = Field(default="embedding", description="Reranking selector")
    fusion_weight: float = Field(default=0.5, ge=0.0, le=1.0, description="Weight for score fusion")
84
+
85
+
86
class AdaptiveSelectorConfig(SelectorConfig):
    """Configuration for adaptive selector.

    Chooses among multiple underlying strategies at runtime
    (bandit / ensemble / threshold based).
    """

    name: str = "adaptive"
    strategies: list[str] = Field(
        default_factory=lambda: ["keyword", "embedding"], description="List of strategies"
    )
    selection_method: str = Field(
        default="bandit", description="Selection method: bandit, ensemble, threshold"
    )
    exploration_rate: float = Field(
        default=0.1, ge=0.0, le=1.0, description="Exploration rate for bandit"
    )
    update_interval: int = Field(default=100, ge=1, description="Update interval for adaptation")
100
+
101
+
102
class DFSDTSelectorConfig(SelectorConfig):
    """
    Configuration for DFSDT (Depth-First Search-based Decision Tree) selector.

    Based on ToolLLM paper (Qin et al., 2023):
    "ToolLLM: Facilitating Large Language Models to Master 16000+ Real-world APIs"
    """

    name: str = "dfsdt"
    max_depth: int = Field(default=3, ge=1, le=10, description="Maximum search depth")
    beam_width: int = Field(default=5, ge=1, le=20, description="Number of candidates per level")
    llm_model: str = Field(
        default="auto", description="LLM model for scoring (auto uses UnifiedInferenceClient)"
    )
    temperature: float = Field(default=0.1, ge=0.0, le=2.0, description="LLM sampling temperature")
    use_diversity_prompt: bool = Field(
        default=True, description="Use diversity prompting for exploration"
    )
    score_threshold: float = Field(
        default=0.3, ge=0.0, le=1.0, description="Minimum score threshold for pruning"
    )
    # Pre-filtering with the cheap keyword selector bounds the number of
    # candidates the (expensive) LLM search must explore.
    use_keyword_prefilter: bool = Field(
        default=True, description="Use keyword matching to pre-filter candidates"
    )
    prefilter_k: int = Field(
        default=20, ge=5, le=100, description="Number of candidates after pre-filtering"
    )
129
+
130
+
131
class GorillaSelectorConfig(SelectorConfig):
    """
    Configuration for Gorilla-style retrieval-augmented selector.

    Based on Gorilla paper (Patil et al., 2023):
    "Gorilla: Large Language Model Connected with Massive APIs"

    Two-stage approach: embedding retrieval + LLM selection.
    """

    name: str = "gorilla"
    top_k_retrieve: int = Field(
        default=20, ge=1, description="Number of tools to retrieve in first stage"
    )
    top_k_select: int = Field(
        default=5, ge=1, description="Number of tools to select in final output"
    )
    embedding_model: str = Field(default="default", description="Embedding model for retrieval")
    llm_model: str = Field(
        default="auto", description="LLM model for selection (auto uses UnifiedInferenceClient)"
    )
    similarity_metric: str = Field(
        default="cosine", description="Similarity metric: cosine, dot, euclidean"
    )
    temperature: float = Field(
        default=0.1, ge=0.0, le=2.0, description="LLM temperature for selection"
    )
    use_detailed_docs: bool = Field(
        default=True, description="Include detailed parameter docs in context"
    )
    # Caps prompt size: retrieved tools beyond this are not shown to the LLM.
    max_context_tools: int = Field(
        default=15, ge=1, description="Max tools to include in LLM context"
    )
164
+
165
+
166
# Config type registry: maps selector strategy name to the typed
# SelectorConfig subclass instantiated by create_selector_config().
CONFIG_TYPES = {
    "keyword": KeywordSelectorConfig,
    "embedding": EmbeddingSelectorConfig,
    "two_stage": TwoStageSelectorConfig,
    "adaptive": AdaptiveSelectorConfig,
    "dfsdt": DFSDTSelectorConfig,
    "gorilla": GorillaSelectorConfig,
}
175
+
176
+
177
def create_selector_config(config_dict: dict[str, Any]) -> SelectorConfig:
    """
    Create appropriate selector config from dictionary.

    The "name" key (default: "keyword") picks the config class from
    CONFIG_TYPES; the remaining keys are passed to its constructor.

    Args:
        config_dict: Configuration dictionary

    Returns:
        Typed SelectorConfig subclass instance

    Raises:
        ValueError: If selector name not recognized
    """
    selector_name = config_dict.get("name", "keyword")

    if selector_name not in CONFIG_TYPES:
        # List the valid names, mirroring SelectorRegistry's error style.
        raise ValueError(
            f"Unknown selector type: {selector_name}. "
            f"Available: {list(CONFIG_TYPES.keys())}"
        )

    config_class = CONFIG_TYPES[selector_name]
    return config_class(**config_dict)