isage_tooluse-0.1.0.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,495 @@
+ """
+ Gorilla-style retrieval-augmented tool selector.
+
+ Implements the retrieval-augmented generation (RAG) approach from the Gorilla
+ paper for tool selection: embedding retrieval finds relevant API documentation,
+ then an LLM makes the final selection based on the retrieved context.
+
+ Reference:
+     Patil et al. (2023) "Gorilla: Large Language Model Connected with Massive APIs"
+     https://arxiv.org/abs/2305.15334
+ """
+
+ import json
+ import logging
+ from dataclasses import dataclass, field
+ from typing import Any, Optional
+
+ from pydantic import Field
+
+ from .base import BaseToolSelector, SelectorResources
+ from .schemas import SelectorConfig, ToolPrediction, ToolSelectionQuery
+
+ logger = logging.getLogger(__name__)
+
+
+ class GorillaSelectorConfig(SelectorConfig):
+     """Configuration for Gorilla-style retrieval-augmented selector."""
+
+     name: str = "gorilla"
+     top_k_retrieve: int = Field(
+         default=20, ge=1, description="Number of tools to retrieve in first stage"
+     )
+     top_k_select: int = Field(
+         default=5, ge=1, description="Number of tools to select in final output"
+     )
+     embedding_model: str = Field(default="default", description="Embedding model for retrieval")
+     llm_model: str = Field(
+         default="auto", description="LLM model for selection (auto uses UnifiedInferenceClient)"
+     )
+     similarity_metric: str = Field(
+         default="cosine", description="Similarity metric: cosine, dot, euclidean"
+     )
+     temperature: float = Field(
+         default=0.1, ge=0.0, le=2.0, description="LLM temperature for selection"
+     )
+     use_detailed_docs: bool = Field(
+         default=True, description="Include detailed parameter docs in context"
+     )
+     max_context_tools: int = Field(
+         default=15, ge=1, description="Max tools to include in LLM context"
+     )
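+
+ # Illustrative config override (values are arbitrary; unspecified fields keep
+ # the pydantic defaults declared above):
+ #
+ #     cfg = GorillaSelectorConfig(top_k_retrieve=30, temperature=0.0)
+ #     cfg.similarity_metric  # -> "cosine"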
+
+
+ @dataclass
+ class RetrievedToolDoc:
+     """Retrieved tool documentation."""
+
+     tool_id: str
+     name: str
+     description: str
+     retrieval_score: float
+     parameters: dict[str, Any] = field(default_factory=dict)
+     category: str = ""
+
+
+ # Sentinel value to indicate auto-creation of LLM client
+ _AUTO_LLM = object()
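+
+ # Usage sketch (illustrative): the three ways the ``llm_client`` argument of
+ # GorillaSelector (defined below) can be supplied:
+ #
+ #     GorillaSelector(cfg, res)                    # default _AUTO_LLM: auto-create
+ #     GorillaSelector(cfg, res, llm_client=None)   # retrieval-only, no LLM stage
+ #     GorillaSelector(cfg, res, llm_client=client) # bring your own client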
+
+
+ class GorillaSelector(BaseToolSelector):
+     """
+     Gorilla-style retrieval-augmented tool selector.
+
+     Two-stage approach:
+     1. Retrieval: Use embedding similarity to retrieve top-k candidate tools
+     2. Selection: Use LLM to analyze retrieved tool docs and select best matches
+
+     This approach leverages the strengths of both embedding-based retrieval
+     (efficient large-scale search) and LLM reasoning (understanding nuanced
+     requirements and API semantics).
+
+     Attributes:
+         config: Gorilla selector configuration
+         resources: Shared resources (tools_loader, embedding_client)
+         llm_client: LLM client for selection stage
+         _tool_docs / _tool_embeddings / _tool_ids: Cached tool docs and embedding index
+     """
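+
+     # End-to-end sketch (illustrative; assumes the inherited BaseToolSelector
+     # exposes a public select() wrapper around _select_impl(), and that
+     # `resources` bundles a tools_loader plus an EmbeddingService):
+     #
+     #     selector = GorillaSelector.from_config(GorillaSelectorConfig(), resources)
+     #     preds = selector.select(ToolSelectionQuery(instruction="convert a PDF to text"))
+     #     [p.tool_id for p in preds]  # top tool IDs, best match first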
+
+     def __init__(
+         self,
+         config: GorillaSelectorConfig,
+         resources: SelectorResources,
+         llm_client: Any = _AUTO_LLM,
+     ):
+         """
+         Initialize Gorilla selector.
+
+         Args:
+             config: Selector configuration
+             resources: Shared resources including embedding_client
+             llm_client: LLM client for selection. Pass None to disable LLM and use
+                 retrieval-only mode. Omit or pass _AUTO_LLM for auto-creation.
+
+         Raises:
+             ValueError: If embedding_client is not provided
+         """
+         super().__init__(config, resources)
+         self.config: GorillaSelectorConfig = config
+
+         # Validate embedding client
+         if not resources.embedding_client:
+             raise ValueError(
+                 "GorillaSelector requires embedding_client in SelectorResources. "
+                 "Please provide an EmbeddingService instance."
+             )
+
+         self.embedding_client = resources.embedding_client
+
+         # Initialize LLM client:
+         # - llm_client=None: explicitly disable LLM, use retrieval-only mode
+         # - llm_client=_AUTO_LLM (default): auto-create LLM client
+         # - llm_client=<client>: use provided client
+         if llm_client is None:
+             self.llm_client = None
+         elif llm_client is _AUTO_LLM:
+             self.llm_client = self._create_llm_client()
+         else:
+             self.llm_client = llm_client
+
+         # Build tool index and cache tool metadata
+         self._tool_docs: dict[str, RetrievedToolDoc] = {}
+         self._tool_embeddings: Optional[Any] = None
+         self._tool_ids: list[str] = []
+         self._preprocess_tools()
+
+     def _create_llm_client(self) -> Any:
+         """Create LLM client for selection stage."""
+         try:
+             from sage.llm import UnifiedInferenceClient
+
+             # Always use create() for automatic local-first detection
+             return UnifiedInferenceClient.create()
+         except ImportError:
+             logger.warning(
+                 "UnifiedInferenceClient not available. GorillaSelector will use "
+                 "embedding-only mode (no LLM reranking)."
+             )
+             return None
+         except Exception as e:
+             logger.warning(f"Failed to create LLM client: {e}. Using retrieval-only mode.")
+             return None
+
+     @classmethod
+     def from_config(cls, config: SelectorConfig, resources: SelectorResources) -> "GorillaSelector":
+         """Create Gorilla selector from config."""
+         if not isinstance(config, GorillaSelectorConfig):
+             # Convert generic config to GorillaSelectorConfig
+             config = GorillaSelectorConfig(**config.model_dump())
+         return cls(config, resources)
+
+     def _preprocess_tools(self) -> None:
+         """Preprocess all tools and build embeddings index."""
+         import numpy as np
+
+         try:
+             tools_loader = self.resources.tools_loader
+
+             # Collect tool metadata
+             tool_texts = []
+
+             for tool in tools_loader.iter_all():
+                 doc = RetrievedToolDoc(
+                     tool_id=tool.tool_id,
+                     name=tool.name,
+                     description=getattr(tool, "description", ""),
+                     retrieval_score=0.0,
+                     parameters=getattr(tool, "parameters", {}),
+                     category=getattr(tool, "category", ""),
+                 )
+                 self._tool_docs[tool.tool_id] = doc
+                 self._tool_ids.append(tool.tool_id)
+
+                 # Build searchable text
+                 text = self._build_tool_text(doc)
+                 tool_texts.append(text)
+
+             if not tool_texts:
+                 self.logger.warning("No tools found to preprocess")
+                 return
+
+             self.logger.info(f"Embedding {len(tool_texts)} tools for Gorilla retrieval...")
+
+             # Embed all tools
+             embeddings = self.embedding_client.embed(
+                 texts=tool_texts,
+                 model=self.config.embedding_model
+                 if self.config.embedding_model != "default"
+                 else None,
+             )
+
+             self._tool_embeddings = np.asarray(embeddings)
+             if self._tool_embeddings.ndim == 1:
+                 self._tool_embeddings = self._tool_embeddings.reshape(1, -1)
+
+             self.logger.info(
+                 f"Built Gorilla index with {len(self._tool_ids)} tools "
+                 f"(dim={self._tool_embeddings.shape[1]})"
+             )
+
+         except Exception as e:
+             self.logger.error(f"Error preprocessing tools for Gorilla: {e}")
+             raise
+
+     def _build_tool_text(self, doc: RetrievedToolDoc) -> str:
+         """Build searchable text from tool documentation."""
+         parts = [doc.name]
+
+         if doc.description:
+             parts.append(doc.description)
+
+         if doc.category:
+             parts.append(f"Category: {doc.category}")
+
+         if self.config.use_detailed_docs and doc.parameters:
+             param_desc = []
+             for param_name, param_info in doc.parameters.items():
+                 if isinstance(param_info, dict) and "description" in param_info:
+                     param_desc.append(f"{param_name}: {param_info['description']}")
+             if param_desc:
+                 parts.append("Parameters: " + "; ".join(param_desc))
+
+         return " ".join(parts)
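+
+     # For a hypothetical tool, the text built above might read (illustrative):
+     #   "get_weather Fetch current conditions for a city Category: weather
+     #    Parameters: city: Target city name; units: metric or imperial"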
+
+     def _retrieve_candidates(
+         self, query: str, candidate_ids: Optional[set[str]], top_k: int
+     ) -> list[RetrievedToolDoc]:
+         """
+         Retrieve candidate tools using embedding similarity.
+
+         Args:
+             query: User instruction
+             candidate_ids: Optional set of valid candidate tool IDs
+             top_k: Number of candidates to retrieve
+
+         Returns:
+             List of retrieved tool docs with scores
+         """
+         import numpy as np
+
+         if self._tool_embeddings is None:
+             return []
+
+         # Embed query
+         query_embedding = self.embedding_client.embed(
+             texts=[query],
+             model=self.config.embedding_model if self.config.embedding_model != "default" else None,
+         )
+         query_vector = np.asarray(query_embedding)[0]
+
+         # Compute similarities
+         if self.config.similarity_metric == "cosine":
+             # Normalize for cosine similarity
+             query_norm = query_vector / (np.linalg.norm(query_vector) + 1e-8)
+             tool_norms = self._tool_embeddings / (
+                 np.linalg.norm(self._tool_embeddings, axis=1, keepdims=True) + 1e-8
+             )
+             scores = np.dot(tool_norms, query_norm)
+         elif self.config.similarity_metric == "dot":
+             scores = np.dot(self._tool_embeddings, query_vector)
+         else:  # euclidean
+             distances = np.linalg.norm(self._tool_embeddings - query_vector, axis=1)
+             scores = 1.0 / (1.0 + distances)
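+         # Worked example (illustrative): distances 0.0, 1.0, 3.0 map to
+         # similarities 1.0, 0.5, 0.25, so higher still means a better match.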
+
+         # Filter by candidate_ids if specified
+         if candidate_ids:
+             valid_indices = [i for i, tid in enumerate(self._tool_ids) if tid in candidate_ids]
+             if not valid_indices:
+                 return []
+             filtered_scores = [(i, scores[i]) for i in valid_indices]
+         else:
+             filtered_scores = list(enumerate(scores))
+
+         # Sort by score and take top-k
+         filtered_scores.sort(key=lambda x: x[1], reverse=True)
+         top_results = filtered_scores[:top_k]
+
+         # Build retrieved docs
+         retrieved = []
+         for idx, score in top_results:
+             tool_id = self._tool_ids[idx]
+             doc = self._tool_docs[tool_id]
+             doc.retrieval_score = float(score)
+             retrieved.append(doc)
+
+         return retrieved
+
+     def _build_llm_prompt(
+         self, query: str, retrieved_docs: list[RetrievedToolDoc], top_k: int
+     ) -> str:
+         """Build prompt for LLM selection."""
+         # Limit context size
+         docs_for_context = retrieved_docs[: self.config.max_context_tools]
+
+         # Build tool documentation string
+         tool_docs_str = []
+         for i, doc in enumerate(docs_for_context, 1):
+             doc_str = f"{i}. **{doc.name}** (ID: `{doc.tool_id}`)\n"
+             doc_str += f"   Description: {doc.description}\n"
+             if doc.category:
+                 doc_str += f"   Category: {doc.category}\n"
+             if doc.parameters and self.config.use_detailed_docs:
+                 params = []
+                 for pname, pinfo in list(doc.parameters.items())[:5]:  # Limit params
+                     if isinstance(pinfo, dict):
+                         ptype = pinfo.get("type", "any")
+                         pdesc = pinfo.get("description", "")[:100]
+                         params.append(f"{pname} ({ptype}): {pdesc}")
+                 if params:
+                     doc_str += f"   Parameters: {'; '.join(params)}\n"
+             tool_docs_str.append(doc_str)
+
+         tools_text = "\n".join(tool_docs_str)
+
+         prompt = f"""You are an expert API selector. Given a user task and a list of available APIs/tools,
+ select the {top_k} most relevant tools that can help complete the task.
+
+ ## User Task
+ {query}
+
+ ## Available Tools
+ {tools_text}
+
+ ## Instructions
+ 1. Analyze the user's task requirements carefully
+ 2. Consider which tools have the capabilities to fulfill the requirements
+ 3. Select exactly {top_k} tools, ordered by relevance (most relevant first)
+ 4. Return ONLY a JSON array of tool IDs, no explanation needed
+
+ ## Output Format
+ Return a JSON array of tool IDs:
+ ["tool_id_1", "tool_id_2", ...]
+
+ ## Your Selection (JSON array only):"""
+
+         return prompt
+
+     def _parse_llm_response(
+         self, response: str, retrieved_docs: list[RetrievedToolDoc]
+     ) -> list[str]:
+         """Parse LLM response to extract selected tool IDs."""
+         # Get valid tool IDs from retrieved docs
+         valid_ids = {doc.tool_id for doc in retrieved_docs}
+
+         # Try to parse JSON array
+         response = response.strip()
+
+         # Remove markdown code block if present
+         if response.startswith("```"):
+             lines = response.split("\n")
+             response = "\n".join(lines[1:-1] if lines[-1] == "```" else lines[1:])
+             response = response.strip()
+
+         try:
+             selected = json.loads(response)
+             if isinstance(selected, list):
+                 # Filter to only valid IDs
+                 return [tid for tid in selected if tid in valid_ids]
+         except json.JSONDecodeError:
+             pass
+
+         # Fallback: try to extract tool IDs from text
+         extracted = []
+         for doc in retrieved_docs:
+             if doc.tool_id in response or doc.name in response:
+                 extracted.append(doc.tool_id)
+
+         return extracted
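+
+     # Parse sketch (illustrative response): a fenced reply such as
+     #   '```json\n["search_api", "geo_api"]\n```'
+     # is unwrapped and JSON-parsed to ["search_api", "geo_api"]; any ID not in
+     # the retrieved candidate set is dropped.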
+
+     def _select_impl(self, query: ToolSelectionQuery, top_k: int) -> list[ToolPrediction]:
+         """
+         Select tools using Gorilla retrieval-augmented approach.
+
+         Args:
+             query: Tool selection query
+             top_k: Number of tools to select
+
+         Returns:
+             List of tool predictions
+         """
+         # Filter candidates if specified
+         candidate_ids = set(query.candidate_tools) if query.candidate_tools else None
+
+         # Stage 1: Retrieve candidates using embedding
+         retrieve_k = max(self.config.top_k_retrieve, top_k * 3)
+         retrieved_docs = self._retrieve_candidates(query.instruction, candidate_ids, retrieve_k)
+
+         if not retrieved_docs:
+             self.logger.warning(f"No tools retrieved for query {query.sample_id}")
+             return []
+
+         # If no LLM client, fall back to retrieval-only
+         if self.llm_client is None:
+             return self._retrieval_only_select(retrieved_docs, top_k)
+
+         # Stage 2: LLM selection from retrieved candidates
+         try:
+             prompt = self._build_llm_prompt(query.instruction, retrieved_docs, top_k)
+
+             response = self.llm_client.chat(
+                 messages=[{"role": "user", "content": prompt}],
+                 temperature=self.config.temperature,
+             )
+
+             selected_ids = self._parse_llm_response(response, retrieved_docs)
+
+             # Build predictions with scores
+             predictions = []
+             retrieval_scores = {doc.tool_id: doc.retrieval_score for doc in retrieved_docs}
+
+             for rank, tool_id in enumerate(selected_ids[:top_k]):
+                 # Score = combination of LLM rank and retrieval score
+                 llm_score = 1.0 - (rank / top_k) * 0.5  # 1.0 -> 0.5 based on rank
+                 retrieval_score = max(0.0, min(1.0, retrieval_scores.get(tool_id, 0.0)))
+                 combined_score = max(0.0, min(1.0, 0.6 * llm_score + 0.4 * retrieval_score))
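+                 # e.g. with top_k=5, ranks 0..4 give llm_score 1.0, 0.9, 0.8, 0.7, 0.6;
+                 # a rank-0 tool with retrieval score 0.75 scores 0.6*1.0 + 0.4*0.75 = 0.9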
+
+                 predictions.append(
+                     ToolPrediction(
+                         tool_id=tool_id,
+                         score=combined_score,
+                         explanation=f"LLM rank: {rank + 1}, retrieval score: {retrieval_score:.3f}",
+                         metadata={
+                             "method": "gorilla",
+                             "llm_rank": rank + 1,
+                             "retrieval_score": retrieval_score,
+                         },
+                     )
+                 )
+
+             # If LLM didn't return enough, supplement with retrieval results
+             if len(predictions) < top_k:
+                 existing_ids = {p.tool_id for p in predictions}
+                 for doc in retrieved_docs:
+                     if doc.tool_id not in existing_ids and len(predictions) < top_k:
+                         # Clamp score to [0, 1] range
+                         score = max(0.0, min(1.0, doc.retrieval_score * 0.8))
+                         predictions.append(
+                             ToolPrediction(
+                                 tool_id=doc.tool_id,
+                                 score=score,
+                                 metadata={
+                                     "method": "gorilla_fallback",
+                                     "retrieval_score": doc.retrieval_score,
+                                 },
+                             )
+                         )
+
+             return predictions
+
+         except Exception as e:
+             self.logger.warning(f"LLM selection failed, falling back to retrieval: {e}")
+             return self._retrieval_only_select(retrieved_docs, top_k)
+
+     def _retrieval_only_select(
+         self, retrieved_docs: list[RetrievedToolDoc], top_k: int
+     ) -> list[ToolPrediction]:
+         """Fallback to retrieval-only selection when LLM unavailable."""
+         predictions = []
+         for doc in retrieved_docs[:top_k]:
+             # Clamp score to [0, 1] range (cosine similarity can be negative)
+             score = max(0.0, min(1.0, doc.retrieval_score))
+             predictions.append(
+                 ToolPrediction(
+                     tool_id=doc.tool_id,
+                     score=score,
+                     metadata={
+                         "method": "gorilla_retrieval_only",
+                         "retrieval_score": doc.retrieval_score,
+                     },
+                 )
+             )
+         return predictions
+
+     def get_stats(self) -> dict:
+         """Get selector statistics."""
+         stats = super().get_stats()
+         stats.update(
+             {
+                 "num_tools": len(self._tool_ids),
+                 "embedding_model": self.config.embedding_model,
+                 "llm_model": self.config.llm_model,
+                 "top_k_retrieve": self.config.top_k_retrieve,
+                 "has_llm_client": self.llm_client is not None,
+             }
+         )
+         return stats
@@ -0,0 +1,202 @@
+ """
+ Hybrid tool selector.
+
+ Combines keyword and embedding-based selection strategies using score fusion.
+ """
+
+ import logging
+ from typing import Optional
+
+ from .base import BaseToolSelector, SelectorResources
+ from .embedding_selector import EmbeddingSelector
+ from .keyword_selector import KeywordSelector
+ from .schemas import (
+     EmbeddingSelectorConfig,
+     KeywordSelectorConfig,
+     SelectorConfig,
+     ToolPrediction,
+     ToolSelectionQuery,
+ )
+
+ logger = logging.getLogger(__name__)
+
+
+ class HybridSelectorConfig(SelectorConfig):
+     """Configuration for hybrid selector."""
+
+     name: str = "hybrid"
+     keyword_weight: float = 0.4
+     embedding_weight: float = 0.6
+     keyword_method: str = "bm25"
+     embedding_model: str = "default"
+     fusion_method: str = "weighted_sum"  # weighted_sum, max, reciprocal_rank
+
+
+ class HybridSelector(BaseToolSelector):
+     """
+     Hybrid tool selector combining keyword and embedding strategies.
+
+     Uses score fusion to combine results from both approaches:
+     - Keyword matching: Fast, works well for exact matches
+     - Embedding similarity: Better semantic understanding
+
+     Fusion methods (see the weighted-sum sketch after this docstring):
+     - weighted_sum: Linear combination of normalized scores
+     - max: Maximum score from either method
+     - reciprocal_rank: Reciprocal Rank Fusion (RRF)
+     """
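+
+     # Weighted-sum sketch (illustrative numbers, default weights 0.4/0.6):
+     #   kw_score=0.5, emb_score=0.8 -> 0.4*0.5 + 0.6*0.8 = 0.68
+     # "max" fusion would give the same pair 0.8.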
+
+     def __init__(self, config: HybridSelectorConfig, resources: SelectorResources):
+         """
+         Initialize hybrid selector.
+
+         Args:
+             config: Hybrid selector configuration
+             resources: Shared resources including embedding_client
+
+         Note:
+             If embedding_client is not available, falls back to keyword-only mode.
+         """
+         super().__init__(config, resources)
+         self.config: HybridSelectorConfig = config
+
+         # Initialize keyword selector
+         keyword_config = KeywordSelectorConfig(
+             name="keyword",
+             method=config.keyword_method,
+             top_k=config.top_k * 2,  # Get more candidates for fusion
+         )
+         self._keyword_selector = KeywordSelector(keyword_config, resources)
+
+         # Initialize embedding selector if client available
+         self._embedding_selector: Optional[EmbeddingSelector] = None
+         self._embedding_available = False
+
+         if resources.embedding_client:
+             try:
+                 embedding_config = EmbeddingSelectorConfig(
+                     name="embedding",
+                     embedding_model=config.embedding_model,
+                     top_k=config.top_k * 2,
+                 )
+                 self._embedding_selector = EmbeddingSelector(embedding_config, resources)
+                 self._embedding_available = True
+                 self.logger.info("Hybrid selector: Embedding + Keyword mode")
+             except Exception as e:
+                 self.logger.warning(f"Could not initialize embedding selector: {e}")
+                 self.logger.info("Hybrid selector: Keyword-only mode")
+         else:
+             self.logger.info("Hybrid selector: Keyword-only mode (no embedding client)")
+
+     @classmethod
+     def from_config(cls, config: SelectorConfig, resources: SelectorResources) -> "HybridSelector":
+         """Create hybrid selector from config."""
+         if not isinstance(config, HybridSelectorConfig):
+             # Convert generic config to HybridSelectorConfig
+             config = HybridSelectorConfig(
+                 name=config.name,
+                 top_k=config.top_k,
+                 min_score=config.min_score,
+                 cache_enabled=config.cache_enabled,
+                 params=config.params,
+             )
+         return cls(config, resources)
+
+     def _select_impl(self, query: ToolSelectionQuery, top_k: int) -> list[ToolPrediction]:
+         """
+         Select tools using hybrid approach.
+
+         Args:
+             query: Tool selection query
+             top_k: Number of tools to select
+
+         Returns:
+             List of tool predictions from fused scores
+         """
+         # Get keyword results
+         keyword_results = self._keyword_selector._select_impl(query, top_k * 2)
+         keyword_scores = {p.tool_id: p.score for p in keyword_results}
+
+         # Get embedding results if available; initialize embedding_results so the
+         # RRF branch below cannot hit an unbound name if the selector raises
+         embedding_results: list[ToolPrediction] = []
+         embedding_scores = {}
+         if self._embedding_available and self._embedding_selector:
+             try:
+                 embedding_results = self._embedding_selector._select_impl(query, top_k * 2)
+                 embedding_scores = {p.tool_id: p.score for p in embedding_results}
+             except Exception as e:
+                 self.logger.warning(f"Embedding selection failed, using keyword only: {e}")
+
+         # Fuse scores
+         all_tool_ids = set(keyword_scores.keys()) | set(embedding_scores.keys())
+         fused_predictions = []
+
+         for tool_id in all_tool_ids:
+             kw_score = keyword_scores.get(tool_id, 0.0)
+             emb_score = embedding_scores.get(tool_id, 0.0)
+
+             if self.config.fusion_method == "weighted_sum":
+                 # Normalize and combine
+                 if self._embedding_available:
+                     final_score = (
+                         self.config.keyword_weight * kw_score
+                         + self.config.embedding_weight * emb_score
+                     )
+                 else:
+                     final_score = kw_score
+
+             elif self.config.fusion_method == "max":
+                 final_score = max(kw_score, emb_score)
+
+             elif self.config.fusion_method == "reciprocal_rank":
+                 # RRF: 1/(k+rank)
+                 k = 60  # Standard RRF constant
+                 kw_rank = self._get_rank(tool_id, keyword_results)
+                 emb_rank = (
+                     self._get_rank(tool_id, embedding_results)
+                     if self._embedding_available
+                     else float("inf")
+                 )
+
+                 kw_rrf = 1.0 / (k + kw_rank) if kw_rank < float("inf") else 0.0
+                 emb_rrf = 1.0 / (k + emb_rank) if emb_rank < float("inf") else 0.0
+
+                 final_score = kw_rrf + emb_rrf
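+                 # e.g. kw_rank=1, emb_rank=3 with k=60 -> 1/61 + 1/63 ≈ 0.0323;
+                 # RRF totals stay well below 1.0, so the min() clamp below is inert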
+             else:
+                 final_score = kw_score
+
+             fused_predictions.append(
+                 ToolPrediction(
+                     tool_id=tool_id,
+                     score=min(final_score, 1.0),
+                     metadata={
+                         "keyword_score": kw_score,
+                         "embedding_score": emb_score,
+                         "fusion_method": self.config.fusion_method,
+                     },
+                 )
+             )
+
+         # Sort by fused score
+         fused_predictions.sort(key=lambda p: p.score, reverse=True)
+
+         return fused_predictions[:top_k]
+
+     def _get_rank(self, tool_id: str, predictions: list[ToolPrediction]) -> float:
+         """Get rank of tool_id in predictions list (1-indexed)."""
+         for i, p in enumerate(predictions):
+             if p.tool_id == tool_id:
+                 return i + 1
+         return float("inf")
+
+     def get_stats(self) -> dict:
+         """Get selector statistics."""
+         stats = super().get_stats()
+         stats.update(
+             {
+                 "embedding_available": self._embedding_available,
+                 "fusion_method": self.config.fusion_method,
+                 "keyword_weight": self.config.keyword_weight,
+                 "embedding_weight": self.config.embedding_weight,
+             }
+         )
+         return stats