isage_tooluse-0.1.0.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- isage_tooluse-0.1.0.0.dist-info/METADATA +208 -0
- isage_tooluse-0.1.0.0.dist-info/RECORD +14 -0
- isage_tooluse-0.1.0.0.dist-info/WHEEL +5 -0
- isage_tooluse-0.1.0.0.dist-info/licenses/LICENSE +21 -0
- isage_tooluse-0.1.0.0.dist-info/top_level.txt +1 -0
- sage_libs/sage_tooluse/__init__.py +75 -0
- sage_libs/sage_tooluse/base.py +203 -0
- sage_libs/sage_tooluse/dfsdt_selector.py +402 -0
- sage_libs/sage_tooluse/embedding_selector.py +281 -0
- sage_libs/sage_tooluse/gorilla_selector.py +495 -0
- sage_libs/sage_tooluse/hybrid_selector.py +202 -0
- sage_libs/sage_tooluse/keyword_selector.py +270 -0
- sage_libs/sage_tooluse/registry.py +185 -0
- sage_libs/sage_tooluse/schemas.py +196 -0
|
@@ -0,0 +1,495 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Gorilla-style retrieval-augmented tool selector.
|
|
3
|
+
|
|
4
|
+
Implements the retrieval-augmented generation (RAG) approach from Gorilla paper
|
|
5
|
+
for tool selection. Uses embedding retrieval to find relevant API documentation,
|
|
6
|
+
then prompts LLM to make final selection based on retrieved context.
|
|
7
|
+
|
|
8
|
+
Reference:
|
|
9
|
+
Patil et al. (2023) "Gorilla: Large Language Model Connected with Massive APIs"
|
|
10
|
+
https://arxiv.org/abs/2305.15334
|
|
11
|
+
"""
|
|
12
|
+
|
|
13
|
+
import json
|
|
14
|
+
import logging
|
|
15
|
+
from dataclasses import dataclass, field
|
|
16
|
+
from typing import Any, Optional
|
|
17
|
+
|
|
18
|
+
from pydantic import Field
|
|
19
|
+
|
|
20
|
+
from .base import BaseToolSelector, SelectorResources
|
|
21
|
+
from .schemas import SelectorConfig, ToolPrediction, ToolSelectionQuery
|
|
22
|
+
|
|
23
|
+
logger = logging.getLogger(__name__)
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class GorillaSelectorConfig(SelectorConfig):
    """Configuration for Gorilla-style retrieval-augmented selector.

    Tunes both stages of the Gorilla pipeline: the embedding retrieval
    stage (``top_k_retrieve``, ``embedding_model``, ``similarity_metric``)
    and the LLM selection stage (``top_k_select``, ``llm_model``,
    ``temperature``, ``use_detailed_docs``, ``max_context_tools``).
    """

    name: str = "gorilla"

    # --- Stage 1: embedding retrieval ---
    top_k_retrieve: int = Field(default=20, ge=1, description="Number of tools to retrieve in first stage")
    embedding_model: str = Field(default="default", description="Embedding model for retrieval")
    similarity_metric: str = Field(default="cosine", description="Similarity metric: cosine, dot, euclidean")

    # --- Stage 2: LLM selection ---
    top_k_select: int = Field(default=5, ge=1, description="Number of tools to select in final output")
    llm_model: str = Field(default="auto", description="LLM model for selection (auto uses IntelligentLLMClient)")
    temperature: float = Field(default=0.1, ge=0.0, le=2.0, description="LLM temperature for selection")
    use_detailed_docs: bool = Field(default=True, description="Include detailed parameter docs in context")
    max_context_tools: int = Field(default=15, ge=1, description="Max tools to include in LLM context")
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
@dataclass
class RetrievedToolDoc:
    """Retrieved tool documentation.

    One record per tool, built once at index time by
    ``GorillaSelector._preprocess_tools`` and reused across queries.
    """

    # Unique identifier of the tool (used as the dict key in the index).
    tool_id: str
    # Human-readable tool name.
    name: str
    # Free-text description used for embedding and LLM context.
    description: str
    # Similarity score from the most recent retrieval pass
    # (initialized to 0.0 at index time, overwritten per query).
    retrieval_score: float
    # Parameter schema; values are typically dicts with "type"/"description".
    parameters: dict[str, Any] = field(default_factory=dict)
    # Optional category label included in the searchable text when present.
    category: str = ""
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
# Sentinel value to indicate auto-creation of LLM client.
# Needed because None is already meaningful ("explicitly disable LLM"),
# so the default must be a distinct object compared with `is`.
_AUTO_LLM = object()
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
class GorillaSelector(BaseToolSelector):
    """
    Gorilla-style retrieval-augmented tool selector.

    Two-stage approach:
    1. Retrieval: Use embedding similarity to retrieve top-k candidate tools
    2. Selection: Use LLM to analyze retrieved tool docs and select best matches

    This approach leverages the strengths of both embedding-based retrieval
    (efficient large-scale search) and LLM reasoning (understanding nuanced
    requirements and API semantics).

    Attributes:
        config: Gorilla selector configuration
        resources: Shared resources (tools_loader, embedding_client)
        llm_client: LLM client for selection stage (None => retrieval-only mode)
        _tool_docs: Per-tool documentation records, keyed by tool_id
        _tool_embeddings: 2-D array of tool embeddings, row-aligned with _tool_ids
        _tool_ids: Tool ids in embedding-row order
    """

    def __init__(
        self,
        config: GorillaSelectorConfig,
        resources: SelectorResources,
        llm_client: Any = _AUTO_LLM,
    ):
        """
        Initialize Gorilla selector.

        Args:
            config: Selector configuration
            resources: Shared resources including embedding_client
            llm_client: LLM client for selection. Pass None to disable LLM and use
                retrieval-only mode. Omit or pass _AUTO_LLM for auto-creation.

        Raises:
            ValueError: If embedding_client is not provided
        """
        super().__init__(config, resources)
        self.config: GorillaSelectorConfig = config

        # Validate embedding client — retrieval stage cannot work without it.
        if not resources.embedding_client:
            raise ValueError(
                "GorillaSelector requires embedding_client in SelectorResources. "
                "Please provide an EmbeddingService instance."
            )

        self.embedding_client = resources.embedding_client

        # Initialize LLM client:
        # - llm_client=None: explicitly disable LLM, use retrieval-only mode
        # - llm_client=_AUTO_LLM (default): auto-create LLM client
        # - llm_client=<client>: use provided client
        if llm_client is None:
            self.llm_client = None
        elif llm_client is _AUTO_LLM:
            self.llm_client = self._create_llm_client()
        else:
            self.llm_client = llm_client

        # Build tool index and cache tool metadata (eager: embeds every tool now).
        self._tool_docs: dict[str, RetrievedToolDoc] = {}
        self._tool_embeddings: Optional[Any] = None
        self._tool_ids: list[str] = []
        self._preprocess_tools()

    def _create_llm_client(self) -> Any:
        """Create LLM client for selection stage.

        Returns None (retrieval-only mode) when the client library is missing
        or construction fails, rather than propagating the error.
        """
        try:
            from sage.llm import UnifiedInferenceClient

            # Always use create() for automatic local-first detection
            return UnifiedInferenceClient.create()
        except ImportError:
            logger.warning(
                "UnifiedInferenceClient not available. GorillaSelector will use "
                "embedding-only mode (no LLM reranking)."
            )
            return None
        except Exception as e:
            logger.warning(f"Failed to create LLM client: {e}. Using retrieval-only mode.")
            return None

    @classmethod
    def from_config(cls, config: SelectorConfig, resources: SelectorResources) -> "GorillaSelector":
        """Create Gorilla selector from config."""
        if not isinstance(config, GorillaSelectorConfig):
            # Convert generic config to GorillaSelectorConfig
            config = GorillaSelectorConfig(**config.model_dump())
        return cls(config, resources)

    def _preprocess_tools(self) -> None:
        """Preprocess all tools and build embeddings index.

        Iterates every tool from the loader, caches a RetrievedToolDoc per
        tool, embeds the searchable text of all tools in one batch, and
        stores the result as a (num_tools, dim) array. Re-raises on failure.
        """
        # Local import keeps numpy out of module import time.
        import numpy as np

        try:
            tools_loader = self.resources.tools_loader

            # Collect tool metadata
            tool_texts = []

            for tool in tools_loader.iter_all():
                doc = RetrievedToolDoc(
                    tool_id=tool.tool_id,
                    name=tool.name,
                    description=getattr(tool, "description", ""),
                    retrieval_score=0.0,
                    parameters=getattr(tool, "parameters", {}),
                    category=getattr(tool, "category", ""),
                )
                self._tool_docs[tool.tool_id] = doc
                self._tool_ids.append(tool.tool_id)

                # Build searchable text
                text = self._build_tool_text(doc)
                tool_texts.append(text)

            if not tool_texts:
                self.logger.warning("No tools found to preprocess")
                return

            self.logger.info(f"Embedding {len(tool_texts)} tools for Gorilla retrieval...")

            # Embed all tools in a single batch call.
            # model=None lets the embedding client pick its own default.
            embeddings = self.embedding_client.embed(
                texts=tool_texts,
                model=self.config.embedding_model
                if self.config.embedding_model != "default"
                else None,
            )

            self._tool_embeddings = np.asarray(embeddings)
            # Single-tool corpora may come back 1-D; normalize to 2-D.
            if self._tool_embeddings.ndim == 1:
                self._tool_embeddings = self._tool_embeddings.reshape(1, -1)

            self.logger.info(
                f"Built Gorilla index with {len(self._tool_ids)} tools "
                f"(dim={self._tool_embeddings.shape[1]})"
            )

        except Exception as e:
            self.logger.error(f"Error preprocessing tools for Gorilla: {e}")
            raise

    def _build_tool_text(self, doc: RetrievedToolDoc) -> str:
        """Build searchable text from tool documentation.

        Concatenates name, description, category, and (optionally)
        parameter descriptions into one string for embedding.
        """
        parts = [doc.name]

        if doc.description:
            parts.append(doc.description)

        if doc.category:
            parts.append(f"Category: {doc.category}")

        if self.config.use_detailed_docs and doc.parameters:
            param_desc = []
            for param_name, param_info in doc.parameters.items():
                # Only dict-shaped parameter entries with a description contribute.
                if isinstance(param_info, dict) and "description" in param_info:
                    param_desc.append(f"{param_name}: {param_info['description']}")
            if param_desc:
                parts.append("Parameters: " + "; ".join(param_desc))

        return " ".join(parts)

    def _retrieve_candidates(
        self, query: str, candidate_ids: Optional[set[str]], top_k: int
    ) -> list[RetrievedToolDoc]:
        """
        Retrieve candidate tools using embedding similarity.

        Args:
            query: User instruction
            candidate_ids: Optional set of valid candidate tool IDs
            top_k: Number of candidates to retrieve

        Returns:
            List of retrieved tool docs with scores
        """
        import numpy as np

        # No index (e.g. empty tool corpus) => nothing to retrieve.
        if self._tool_embeddings is None:
            return []

        # Embed query
        query_embedding = self.embedding_client.embed(
            texts=[query],
            model=self.config.embedding_model if self.config.embedding_model != "default" else None,
        )
        query_vector = np.asarray(query_embedding)[0]

        # Compute similarities against all indexed tools at once.
        if self.config.similarity_metric == "cosine":
            # Normalize for cosine similarity (+1e-8 guards zero vectors)
            query_norm = query_vector / (np.linalg.norm(query_vector) + 1e-8)
            tool_norms = self._tool_embeddings / (
                np.linalg.norm(self._tool_embeddings, axis=1, keepdims=True) + 1e-8
            )
            scores = np.dot(tool_norms, query_norm)
        elif self.config.similarity_metric == "dot":
            scores = np.dot(self._tool_embeddings, query_vector)
        else:  # euclidean, mapped to a (0, 1] similarity
            distances = np.linalg.norm(self._tool_embeddings - query_vector, axis=1)
            scores = 1.0 / (1.0 + distances)

        # Filter by candidate_ids if specified.
        # NOTE(review): an *empty* set is falsy and is treated like "no filter";
        # confirm that callers never intend an empty set to mean "no candidates".
        if candidate_ids:
            valid_indices = [i for i, tid in enumerate(self._tool_ids) if tid in candidate_ids]
            if not valid_indices:
                return []
            filtered_scores = [(i, scores[i]) for i in valid_indices]
        else:
            filtered_scores = list(enumerate(scores))

        # Sort by score and take top-k
        filtered_scores.sort(key=lambda x: x[1], reverse=True)
        top_results = filtered_scores[:top_k]

        # Build retrieved docs.
        # NOTE(review): this mutates retrieval_score on the docs cached in
        # self._tool_docs, so scores from one query leak into the shared index
        # until the next query overwrites them (also not thread-safe); consider
        # returning copies (e.g. dataclasses.replace) instead.
        retrieved = []
        for idx, score in top_results:
            tool_id = self._tool_ids[idx]
            doc = self._tool_docs[tool_id]
            doc.retrieval_score = float(score)
            retrieved.append(doc)

        return retrieved

    def _build_llm_prompt(
        self, query: str, retrieved_docs: list[RetrievedToolDoc], top_k: int
    ) -> str:
        """Build prompt for LLM selection.

        Renders up to max_context_tools retrieved docs as a numbered list and
        asks the LLM for a JSON array of tool IDs ordered by relevance.
        """
        # Limit context size
        docs_for_context = retrieved_docs[: self.config.max_context_tools]

        # Build tool documentation string
        tool_docs_str = []
        for i, doc in enumerate(docs_for_context, 1):
            doc_str = f"{i}. **{doc.name}** (ID: `{doc.tool_id}`)\n"
            doc_str += f"   Description: {doc.description}\n"
            if doc.category:
                doc_str += f"   Category: {doc.category}\n"
            if doc.parameters and self.config.use_detailed_docs:
                params = []
                for pname, pinfo in list(doc.parameters.items())[:5]:  # Limit params
                    if isinstance(pinfo, dict):
                        ptype = pinfo.get("type", "any")
                        pdesc = pinfo.get("description", "")[:100]
                        params.append(f"{pname} ({ptype}): {pdesc}")
                if params:
                    doc_str += f"   Parameters: {'; '.join(params)}\n"
            tool_docs_str.append(doc_str)

        tools_text = "\n".join(tool_docs_str)

        prompt = f"""You are an expert API selector. Given a user task and a list of available APIs/tools,
select the {top_k} most relevant tools that can help complete the task.

## User Task
{query}

## Available Tools
{tools_text}

## Instructions
1. Analyze the user's task requirements carefully
2. Consider which tools have the capabilities to fulfill the requirements
3. Select exactly {top_k} tools, ordered by relevance (most relevant first)
4. Return ONLY a JSON array of tool IDs, no explanation needed

## Output Format
Return a JSON array of tool IDs:
["tool_id_1", "tool_id_2", ...]

## Your Selection (JSON array only):"""

        return prompt

    def _parse_llm_response(
        self, response: str, retrieved_docs: list[RetrievedToolDoc]
    ) -> list[str]:
        """Parse LLM response to extract selected tool IDs.

        Tries strict JSON first (after stripping a markdown code fence),
        then falls back to substring scanning over the retrieved docs.
        IDs not present in retrieved_docs are dropped.
        """
        # Get valid tool IDs from retrieved docs
        valid_ids = {doc.tool_id for doc in retrieved_docs}

        # Try to parse JSON array
        response = response.strip()

        # Remove markdown code block if present
        if response.startswith("```"):
            lines = response.split("\n")
            response = "\n".join(lines[1:-1] if lines[-1] == "```" else lines[1:])
            response = response.strip()

        try:
            selected = json.loads(response)
            if isinstance(selected, list):
                # Filter to only valid IDs
                return [tid for tid in selected if tid in valid_ids]
        except json.JSONDecodeError:
            pass

        # Fallback: try to extract tool IDs from text.
        # NOTE(review): plain substring matching — a tool whose id/name is a
        # substring of another's (or of unrelated prose) can be falsely matched.
        extracted = []
        for doc in retrieved_docs:
            if doc.tool_id in response or doc.name in response:
                extracted.append(doc.tool_id)

        return extracted

    def _select_impl(self, query: ToolSelectionQuery, top_k: int) -> list[ToolPrediction]:
        """
        Select tools using Gorilla retrieval-augmented approach.

        Args:
            query: Tool selection query
            top_k: Number of tools to select

        Returns:
            List of tool predictions
        """
        # Filter candidates if specified
        candidate_ids = set(query.candidate_tools) if query.candidate_tools else None

        # Stage 1: Retrieve candidates using embedding.
        # Over-fetch (>= 3x top_k) so the LLM has slack to rerank.
        retrieve_k = max(self.config.top_k_retrieve, top_k * 3)
        retrieved_docs = self._retrieve_candidates(query.instruction, candidate_ids, retrieve_k)

        if not retrieved_docs:
            self.logger.warning(f"No tools retrieved for query {query.sample_id}")
            return []

        # If no LLM client, fall back to retrieval-only
        if self.llm_client is None:
            return self._retrieval_only_select(retrieved_docs, top_k)

        # Stage 2: LLM selection from retrieved candidates
        try:
            prompt = self._build_llm_prompt(query.instruction, retrieved_docs, top_k)

            response = self.llm_client.chat(
                messages=[{"role": "user", "content": prompt}],
                temperature=self.config.temperature,
            )

            selected_ids = self._parse_llm_response(response, retrieved_docs)

            # Build predictions with scores
            predictions = []
            retrieval_scores = {doc.tool_id: doc.retrieval_score for doc in retrieved_docs}

            for rank, tool_id in enumerate(selected_ids[:top_k]):
                # Score = combination of LLM rank and retrieval score
                llm_score = 1.0 - (rank / top_k) * 0.5  # 1.0 -> 0.5 based on rank
                retrieval_score = max(0.0, min(1.0, retrieval_scores.get(tool_id, 0.0)))
                combined_score = max(0.0, min(1.0, 0.6 * llm_score + 0.4 * retrieval_score))

                predictions.append(
                    ToolPrediction(
                        tool_id=tool_id,
                        score=combined_score,
                        explanation=f"LLM rank: {rank + 1}, retrieval score: {retrieval_score:.3f}",
                        metadata={
                            "method": "gorilla",
                            "llm_rank": rank + 1,
                            "retrieval_score": retrieval_score,
                        },
                    )
                )

            # If LLM didn't return enough, supplement with retrieval results
            # (in retrieval-score order, skipping ids the LLM already chose).
            if len(predictions) < top_k:
                existing_ids = {p.tool_id for p in predictions}
                for doc in retrieved_docs:
                    if doc.tool_id not in existing_ids and len(predictions) < top_k:
                        # Clamp score to [0, 1] range; *0.8 keeps fallback
                        # entries below equally-scored LLM picks.
                        score = max(0.0, min(1.0, doc.retrieval_score * 0.8))
                        predictions.append(
                            ToolPrediction(
                                tool_id=doc.tool_id,
                                score=score,
                                metadata={
                                    "method": "gorilla_fallback",
                                    "retrieval_score": doc.retrieval_score,
                                },
                            )
                        )

            return predictions

        except Exception as e:
            self.logger.warning(f"LLM selection failed, falling back to retrieval: {e}")
            return self._retrieval_only_select(retrieved_docs, top_k)

    def _retrieval_only_select(
        self, retrieved_docs: list[RetrievedToolDoc], top_k: int
    ) -> list[ToolPrediction]:
        """Fallback to retrieval-only selection when LLM unavailable."""
        predictions = []
        for doc in retrieved_docs[:top_k]:
            # Clamp score to [0, 1] range (cosine similarity can be negative)
            score = max(0.0, min(1.0, doc.retrieval_score))
            predictions.append(
                ToolPrediction(
                    tool_id=doc.tool_id,
                    score=score,
                    metadata={
                        "method": "gorilla_retrieval_only",
                        "retrieval_score": doc.retrieval_score,
                    },
                )
            )
        return predictions

    def get_stats(self) -> dict:
        """Get selector statistics (base stats plus Gorilla-specific fields)."""
        stats = super().get_stats()
        stats.update(
            {
                "num_tools": len(self._tool_ids),
                "embedding_model": self.config.embedding_model,
                "llm_model": self.config.llm_model,
                "top_k_retrieve": self.config.top_k_retrieve,
                "has_llm_client": self.llm_client is not None,
            }
        )
        return stats
|
|
@@ -0,0 +1,202 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Hybrid tool selector.
|
|
3
|
+
|
|
4
|
+
Combines keyword and embedding-based selection strategies using score fusion.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import logging
|
|
8
|
+
from typing import Optional
|
|
9
|
+
|
|
10
|
+
from .base import BaseToolSelector, SelectorResources
|
|
11
|
+
from .embedding_selector import EmbeddingSelector
|
|
12
|
+
from .keyword_selector import KeywordSelector
|
|
13
|
+
from .schemas import (
|
|
14
|
+
EmbeddingSelectorConfig,
|
|
15
|
+
KeywordSelectorConfig,
|
|
16
|
+
SelectorConfig,
|
|
17
|
+
ToolPrediction,
|
|
18
|
+
ToolSelectionQuery,
|
|
19
|
+
)
|
|
20
|
+
|
|
21
|
+
logger = logging.getLogger(__name__)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class HybridSelectorConfig(SelectorConfig):
    """Configuration for hybrid selector.

    Weights apply only to the "weighted_sum" fusion method; "max" and
    "reciprocal_rank" ignore them.
    """

    name: str = "hybrid"
    # Weight of the keyword (BM25/etc.) score in weighted_sum fusion.
    keyword_weight: float = 0.4
    # Weight of the embedding-similarity score in weighted_sum fusion.
    embedding_weight: float = 0.6
    # Scoring method forwarded to the internal KeywordSelector.
    keyword_method: str = "bm25"
    # Embedding model forwarded to the internal EmbeddingSelector.
    embedding_model: str = "default"
    fusion_method: str = "weighted_sum"  # weighted_sum, max, reciprocal_rank
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class HybridSelector(BaseToolSelector):
    """
    Hybrid tool selector combining keyword and embedding strategies.

    Uses score fusion to combine results from both approaches:
    - Keyword matching: Fast, works well for exact matches
    - Embedding similarity: Better semantic understanding

    Fusion methods:
    - weighted_sum: Linear combination of normalized scores
    - max: Maximum score from either method
    - reciprocal_rank: Reciprocal Rank Fusion (RRF)
    """

    def __init__(self, config: HybridSelectorConfig, resources: SelectorResources):
        """
        Initialize hybrid selector.

        Args:
            config: Hybrid selector configuration
            resources: Shared resources including embedding_client

        Note:
            If embedding_client is not available, falls back to keyword-only mode.
        """
        super().__init__(config, resources)
        self.config: HybridSelectorConfig = config

        # Initialize keyword selector; over-fetch so fusion has extra candidates.
        keyword_config = KeywordSelectorConfig(
            name="keyword",
            method=config.keyword_method,
            top_k=config.top_k * 2,  # Get more candidates for fusion
        )
        self._keyword_selector = KeywordSelector(keyword_config, resources)

        # Initialize embedding selector if client available
        self._embedding_selector: Optional[EmbeddingSelector] = None
        self._embedding_available = False

        if resources.embedding_client:
            try:
                embedding_config = EmbeddingSelectorConfig(
                    name="embedding",
                    embedding_model=config.embedding_model,
                    top_k=config.top_k * 2,
                )
                self._embedding_selector = EmbeddingSelector(embedding_config, resources)
                self._embedding_available = True
                self.logger.info("Hybrid selector: Embedding + Keyword mode")
            except Exception as e:
                self.logger.warning(f"Could not initialize embedding selector: {e}")
                self.logger.info("Hybrid selector: Keyword-only mode")
        else:
            self.logger.info("Hybrid selector: Keyword-only mode (no embedding client)")

    @classmethod
    def from_config(cls, config: SelectorConfig, resources: SelectorResources) -> "HybridSelector":
        """Create hybrid selector from config."""
        if not isinstance(config, HybridSelectorConfig):
            # Convert generic config to HybridSelectorConfig
            config = HybridSelectorConfig(
                name=config.name,
                top_k=config.top_k,
                min_score=config.min_score,
                cache_enabled=config.cache_enabled,
                params=config.params,
            )
        return cls(config, resources)

    def _select_impl(self, query: ToolSelectionQuery, top_k: int) -> list[ToolPrediction]:
        """
        Select tools using hybrid approach.

        Args:
            query: Tool selection query
            top_k: Number of tools to select

        Returns:
            List of tool predictions from fused scores
        """
        # Get keyword results (over-fetched so fusion sees enough candidates)
        keyword_results = self._keyword_selector._select_impl(query, top_k * 2)
        keyword_scores = {p.tool_id: p.score for p in keyword_results}

        # Get embedding results if available.
        # BUG FIX: embedding_results must be bound even when embedding selection
        # is unavailable or raises, because the reciprocal_rank branch below
        # reads it for every tool — previously this was a NameError whenever
        # the embedding selector threw and fusion_method == "reciprocal_rank".
        embedding_results: list[ToolPrediction] = []
        embedding_scores: dict[str, float] = {}
        if self._embedding_available and self._embedding_selector:
            try:
                embedding_results = self._embedding_selector._select_impl(query, top_k * 2)
                embedding_scores = {p.tool_id: p.score for p in embedding_results}
            except Exception as e:
                self.logger.warning(f"Embedding selection failed, using keyword only: {e}")

        # Fuse scores over the union of candidates from both selectors
        all_tool_ids = set(keyword_scores) | set(embedding_scores)
        fused_predictions = []

        for tool_id in all_tool_ids:
            kw_score = keyword_scores.get(tool_id, 0.0)
            emb_score = embedding_scores.get(tool_id, 0.0)

            if self.config.fusion_method == "weighted_sum":
                # Linear combination; keyword-only mode skips the weighting so
                # scores are not scaled down by keyword_weight.
                if self._embedding_available:
                    final_score = (
                        self.config.keyword_weight * kw_score
                        + self.config.embedding_weight * emb_score
                    )
                else:
                    final_score = kw_score

            elif self.config.fusion_method == "max":
                final_score = max(kw_score, emb_score)

            elif self.config.fusion_method == "reciprocal_rank":
                # RRF: 1/(k+rank); k=60 is the standard constant.
                # _get_rank returns inf for tools absent from a list, which
                # safely covers the "no embedding results" case.
                k = 60
                kw_rank = self._get_rank(tool_id, keyword_results)
                emb_rank = self._get_rank(tool_id, embedding_results)

                kw_rrf = 1.0 / (k + kw_rank) if kw_rank < float("inf") else 0
                emb_rrf = 1.0 / (k + emb_rank) if emb_rank < float("inf") else 0

                final_score = kw_rrf + emb_rrf
            else:
                # Unknown fusion method: degrade gracefully to keyword score.
                final_score = kw_score

            fused_predictions.append(
                ToolPrediction(
                    tool_id=tool_id,
                    score=min(final_score, 1.0),
                    metadata={
                        "keyword_score": kw_score,
                        "embedding_score": emb_score,
                        "fusion_method": self.config.fusion_method,
                    },
                )
            )

        # Sort by fused score
        fused_predictions.sort(key=lambda p: p.score, reverse=True)

        return fused_predictions[:top_k]

    def _get_rank(self, tool_id: str, predictions: list[ToolPrediction]) -> float:
        """Get rank of tool_id in predictions list (1-indexed); inf if absent."""
        for i, p in enumerate(predictions):
            if p.tool_id == tool_id:
                return i + 1
        return float("inf")

    def get_stats(self) -> dict:
        """Get selector statistics (base stats plus fusion configuration)."""
        stats = super().get_stats()
        stats.update(
            {
                "embedding_available": self._embedding_available,
                "fusion_method": self.config.fusion_method,
                "keyword_weight": self.config.keyword_weight,
                "embedding_weight": self.config.embedding_weight,
            }
        )
        return stats
|