whichmodel-haystack 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- whichmodel_haystack-0.1.0/.gitignore +9 -0
- whichmodel_haystack-0.1.0/LICENSE.txt +16 -0
- whichmodel_haystack-0.1.0/PKG-INFO +121 -0
- whichmodel_haystack-0.1.0/README.md +90 -0
- whichmodel_haystack-0.1.0/pyproject.toml +69 -0
- whichmodel_haystack-0.1.0/src/haystack_integrations/__init__.py +0 -0
- whichmodel_haystack-0.1.0/src/haystack_integrations/components/__init__.py +0 -0
- whichmodel_haystack-0.1.0/src/haystack_integrations/components/routers/__init__.py +0 -0
- whichmodel_haystack-0.1.0/src/haystack_integrations/components/routers/whichmodel/__init__.py +5 -0
- whichmodel_haystack-0.1.0/src/haystack_integrations/components/routers/whichmodel/py.typed +0 -0
- whichmodel_haystack-0.1.0/src/haystack_integrations/components/routers/whichmodel/router.py +295 -0
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
|
|
2
|
+
Apache License
|
|
3
|
+
Version 2.0, January 2004
|
|
4
|
+
http://www.apache.org/licenses/
|
|
5
|
+
|
|
6
|
+
Licensed under the Apache License, Version 2.0 (the "License");
|
|
7
|
+
you may not use this file except in compliance with the License.
|
|
8
|
+
You may obtain a copy of the License at
|
|
9
|
+
|
|
10
|
+
http://www.apache.org/licenses/LICENSE-2.0
|
|
11
|
+
|
|
12
|
+
Unless required by applicable law or agreed to in writing, software
|
|
13
|
+
distributed under the License is distributed on an "AS IS" BASIS,
|
|
14
|
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
15
|
+
See the License for the specific language governing permissions and
|
|
16
|
+
limitations under the License.
|
|
@@ -0,0 +1,121 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: whichmodel-haystack
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: Haystack integration for WhichModel — cost-aware LLM model selection
|
|
5
|
+
Project-URL: Homepage, https://whichmodel.dev
|
|
6
|
+
Project-URL: Repository, https://github.com/Which-Model/whichmodel-mcp
|
|
7
|
+
Project-URL: Documentation, https://github.com/Which-Model/whichmodel-mcp/tree/main/integrations/haystack-whichmodel
|
|
8
|
+
Author-email: WhichModel <hello@whichmodel.dev>
|
|
9
|
+
License: Apache-2.0
|
|
10
|
+
License-File: LICENSE.txt
|
|
11
|
+
Keywords: cost-optimization,haystack,llm,mcp,model-selection,whichmodel
|
|
12
|
+
Classifier: Development Status :: 4 - Beta
|
|
13
|
+
Classifier: Intended Audience :: Developers
|
|
14
|
+
Classifier: License :: OSI Approved :: Apache Software License
|
|
15
|
+
Classifier: Programming Language :: Python :: 3
|
|
16
|
+
Classifier: Programming Language :: Python :: 3.10
|
|
17
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
18
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
19
|
+
Classifier: Programming Language :: Python :: 3.13
|
|
20
|
+
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
21
|
+
Requires-Python: >=3.10
|
|
22
|
+
Requires-Dist: haystack-ai>=2.19.0
|
|
23
|
+
Requires-Dist: httpx>=0.27.0
|
|
24
|
+
Provides-Extra: dev
|
|
25
|
+
Requires-Dist: mypy>=1.10; extra == 'dev'
|
|
26
|
+
Requires-Dist: pytest-asyncio; extra == 'dev'
|
|
27
|
+
Requires-Dist: pytest-cov; extra == 'dev'
|
|
28
|
+
Requires-Dist: pytest>=7.0; extra == 'dev'
|
|
29
|
+
Requires-Dist: ruff>=0.4.0; extra == 'dev'
|
|
30
|
+
Description-Content-Type: text/markdown
|
|
31
|
+
|
|
32
|
+
# whichmodel-haystack
|
|
33
|
+
|
|
34
|
+
[](https://pypi.org/project/whichmodel-haystack/)
|
|
35
|
+
[](https://pypi.org/project/whichmodel-haystack/)
|
|
36
|
+
|
|
37
|
+
Haystack integration for [WhichModel](https://whichmodel.dev) — cost-aware LLM model selection for your pipelines.
|
|
38
|
+
|
|
39
|
+
## Installation
|
|
40
|
+
|
|
41
|
+
```bash
|
|
42
|
+
pip install whichmodel-haystack
|
|
43
|
+
```
|
|
44
|
+
|
|
45
|
+
## Quick Start
|
|
46
|
+
|
|
47
|
+
```python
|
|
48
|
+
from haystack_integrations.components.routers.whichmodel import WhichModelRouter
|
|
49
|
+
|
|
50
|
+
router = WhichModelRouter()
|
|
51
|
+
result = router.run(task_type="code_generation", complexity="high")
|
|
52
|
+
|
|
53
|
+
print(result["model_id"]) # e.g. "anthropic/claude-sonnet-4"
|
|
54
|
+
print(result["provider"]) # e.g. "anthropic"
|
|
55
|
+
print(result["confidence"]) # "high", "medium", or "low"
|
|
56
|
+
```
|
|
57
|
+
|
|
58
|
+
No API key required. The component calls the public WhichModel MCP server at `https://whichmodel.dev/mcp`.
|
|
59
|
+
|
|
60
|
+
## Usage in a Pipeline
|
|
61
|
+
|
|
62
|
+
Use `WhichModelRouter` to dynamically select the best model before generating:
|
|
63
|
+
|
|
64
|
+
```python
|
|
65
|
+
from haystack import Pipeline
|
|
66
|
+
from haystack.components.generators.chat import OpenAIChatGenerator
|
|
67
|
+
from haystack_integrations.components.routers.whichmodel import WhichModelRouter
|
|
68
|
+
|
|
69
|
+
# Get the best model for the task
|
|
70
|
+
router = WhichModelRouter()
|
|
71
|
+
result = router.run(
|
|
72
|
+
task_type="code_generation",
|
|
73
|
+
complexity="high",
|
|
74
|
+
estimated_input_tokens=2000,
|
|
75
|
+
estimated_output_tokens=1000,
|
|
76
|
+
budget_per_call=0.01,
|
|
77
|
+
requirements={"tool_calling": True},
|
|
78
|
+
)
|
|
79
|
+
|
|
80
|
+
# Use the recommended model
|
|
81
|
+
print(f"Using {result['model_id']} (confidence: {result['confidence']})")
|
|
82
|
+
print(f"Estimated cost: ${result['recommendation']['cost_estimate_usd']:.6f}")
|
|
83
|
+
```
|
|
84
|
+
|
|
85
|
+
## Parameters
|
|
86
|
+
|
|
87
|
+
### Init Parameters
|
|
88
|
+
|
|
89
|
+
| Parameter | Type | Default | Description |
|
|
90
|
+
|-----------|------|---------|-------------|
|
|
91
|
+
| `mcp_endpoint` | `str` | `https://whichmodel.dev/mcp` | WhichModel MCP server URL |
|
|
92
|
+
| `timeout` | `float` | `30.0` | HTTP request timeout in seconds |
|
|
93
|
+
| `default_task_type` | `str` | `None` | Default task type for `run()` |
|
|
94
|
+
| `default_complexity` | `str` | `"medium"` | Default complexity level |
|
|
95
|
+
|
|
96
|
+
### Run Parameters
|
|
97
|
+
|
|
98
|
+
| Parameter | Type | Description |
|
|
99
|
+
|-----------|------|-------------|
|
|
100
|
+
| `task_type` | `str` | Task type: `chat`, `code_generation`, `code_review`, `summarisation`, `translation`, `data_extraction`, `tool_calling`, `creative_writing`, `research`, `classification`, `embedding`, `vision`, `reasoning` |
|
|
101
|
+
| `complexity` | `str` | `"low"`, `"medium"`, or `"high"` |
|
|
102
|
+
| `estimated_input_tokens` | `int` | Expected input size in tokens |
|
|
103
|
+
| `estimated_output_tokens` | `int` | Expected output size in tokens |
|
|
104
|
+
| `budget_per_call` | `float` | Max USD per call |
|
|
105
|
+
| `requirements` | `dict` | Capability requirements: `tool_calling`, `json_output`, `streaming`, `context_window_min`, `providers_include`, `providers_exclude` |
|
|
106
|
+
|
|
107
|
+
### Output
|
|
108
|
+
|
|
109
|
+
| Key | Type | Description |
|
|
110
|
+
|-----|------|-------------|
|
|
111
|
+
| `model_id` | `str` | Recommended model ID (e.g. `anthropic/claude-sonnet-4`) |
|
|
112
|
+
| `provider` | `str` | Provider name |
|
|
113
|
+
| `recommendation` | `dict` | Full recommendation with score, reasoning, pricing |
|
|
114
|
+
| `alternative` | `dict` | Alternative model from different provider/tier |
|
|
115
|
+
| `budget_model` | `dict` | Cheapest viable option |
|
|
116
|
+
| `confidence` | `str` | `"high"`, `"medium"`, or `"low"` |
|
|
117
|
+
| `data_freshness` | `str` | When pricing data was last updated |
|
|
118
|
+
|
|
119
|
+
## License
|
|
120
|
+
|
|
121
|
+
Apache-2.0
|
|
@@ -0,0 +1,90 @@
|
|
|
1
|
+
# whichmodel-haystack
|
|
2
|
+
|
|
3
|
+
[](https://pypi.org/project/whichmodel-haystack/)
|
|
4
|
+
[](https://pypi.org/project/whichmodel-haystack/)
|
|
5
|
+
|
|
6
|
+
Haystack integration for [WhichModel](https://whichmodel.dev) — cost-aware LLM model selection for your pipelines.
|
|
7
|
+
|
|
8
|
+
## Installation
|
|
9
|
+
|
|
10
|
+
```bash
|
|
11
|
+
pip install whichmodel-haystack
|
|
12
|
+
```
|
|
13
|
+
|
|
14
|
+
## Quick Start
|
|
15
|
+
|
|
16
|
+
```python
|
|
17
|
+
from haystack_integrations.components.routers.whichmodel import WhichModelRouter
|
|
18
|
+
|
|
19
|
+
router = WhichModelRouter()
|
|
20
|
+
result = router.run(task_type="code_generation", complexity="high")
|
|
21
|
+
|
|
22
|
+
print(result["model_id"]) # e.g. "anthropic/claude-sonnet-4"
|
|
23
|
+
print(result["provider"]) # e.g. "anthropic"
|
|
24
|
+
print(result["confidence"]) # "high", "medium", or "low"
|
|
25
|
+
```
|
|
26
|
+
|
|
27
|
+
No API key required. The component calls the public WhichModel MCP server at `https://whichmodel.dev/mcp`.
|
|
28
|
+
|
|
29
|
+
## Usage in a Pipeline
|
|
30
|
+
|
|
31
|
+
Use `WhichModelRouter` to dynamically select the best model before generating:
|
|
32
|
+
|
|
33
|
+
```python
|
|
34
|
+
from haystack import Pipeline
|
|
35
|
+
from haystack.components.generators.chat import OpenAIChatGenerator
|
|
36
|
+
from haystack_integrations.components.routers.whichmodel import WhichModelRouter
|
|
37
|
+
|
|
38
|
+
# Get the best model for the task
|
|
39
|
+
router = WhichModelRouter()
|
|
40
|
+
result = router.run(
|
|
41
|
+
task_type="code_generation",
|
|
42
|
+
complexity="high",
|
|
43
|
+
estimated_input_tokens=2000,
|
|
44
|
+
estimated_output_tokens=1000,
|
|
45
|
+
budget_per_call=0.01,
|
|
46
|
+
requirements={"tool_calling": True},
|
|
47
|
+
)
|
|
48
|
+
|
|
49
|
+
# Use the recommended model
|
|
50
|
+
print(f"Using {result['model_id']} (confidence: {result['confidence']})")
|
|
51
|
+
print(f"Estimated cost: ${result['recommendation']['cost_estimate_usd']:.6f}")
|
|
52
|
+
```
|
|
53
|
+
|
|
54
|
+
## Parameters
|
|
55
|
+
|
|
56
|
+
### Init Parameters
|
|
57
|
+
|
|
58
|
+
| Parameter | Type | Default | Description |
|
|
59
|
+
|-----------|------|---------|-------------|
|
|
60
|
+
| `mcp_endpoint` | `str` | `https://whichmodel.dev/mcp` | WhichModel MCP server URL |
|
|
61
|
+
| `timeout` | `float` | `30.0` | HTTP request timeout in seconds |
|
|
62
|
+
| `default_task_type` | `str` | `None` | Default task type for `run()` |
|
|
63
|
+
| `default_complexity` | `str` | `"medium"` | Default complexity level |
|
|
64
|
+
|
|
65
|
+
### Run Parameters
|
|
66
|
+
|
|
67
|
+
| Parameter | Type | Description |
|
|
68
|
+
|-----------|------|-------------|
|
|
69
|
+
| `task_type` | `str` | Task type: `chat`, `code_generation`, `code_review`, `summarisation`, `translation`, `data_extraction`, `tool_calling`, `creative_writing`, `research`, `classification`, `embedding`, `vision`, `reasoning` |
|
|
70
|
+
| `complexity` | `str` | `"low"`, `"medium"`, or `"high"` |
|
|
71
|
+
| `estimated_input_tokens` | `int` | Expected input size in tokens |
|
|
72
|
+
| `estimated_output_tokens` | `int` | Expected output size in tokens |
|
|
73
|
+
| `budget_per_call` | `float` | Max USD per call |
|
|
74
|
+
| `requirements` | `dict` | Capability requirements: `tool_calling`, `json_output`, `streaming`, `context_window_min`, `providers_include`, `providers_exclude` |
|
|
75
|
+
|
|
76
|
+
### Output
|
|
77
|
+
|
|
78
|
+
| Key | Type | Description |
|
|
79
|
+
|-----|------|-------------|
|
|
80
|
+
| `model_id` | `str` | Recommended model ID (e.g. `anthropic/claude-sonnet-4`) |
|
|
81
|
+
| `provider` | `str` | Provider name |
|
|
82
|
+
| `recommendation` | `dict` | Full recommendation with score, reasoning, pricing |
|
|
83
|
+
| `alternative` | `dict` | Alternative model from different provider/tier |
|
|
84
|
+
| `budget_model` | `dict` | Cheapest viable option |
|
|
85
|
+
| `confidence` | `str` | `"high"`, `"medium"`, or `"low"` |
|
|
86
|
+
| `data_freshness` | `str` | When pricing data was last updated |
|
|
87
|
+
|
|
88
|
+
## License
|
|
89
|
+
|
|
90
|
+
Apache-2.0
|
|
@@ -0,0 +1,69 @@
|
|
|
1
|
+
[build-system]
|
|
2
|
+
requires = ["hatchling", "hatch-vcs"]
|
|
3
|
+
build-backend = "hatchling.build"
|
|
4
|
+
|
|
5
|
+
[project]
|
|
6
|
+
name = "whichmodel-haystack"
|
|
7
|
+
dynamic = ["version"]
|
|
8
|
+
description = "Haystack integration for WhichModel — cost-aware LLM model selection"
|
|
9
|
+
readme = "README.md"
|
|
10
|
+
requires-python = ">=3.10"
|
|
11
|
+
license = {text = "Apache-2.0"}
|
|
12
|
+
keywords = ["haystack", "whichmodel", "llm", "model-selection", "cost-optimization", "mcp"]
|
|
13
|
+
authors = [
|
|
14
|
+
{name = "WhichModel", email = "hello@whichmodel.dev"},
|
|
15
|
+
]
|
|
16
|
+
classifiers = [
|
|
17
|
+
"Development Status :: 4 - Beta",
|
|
18
|
+
"Intended Audience :: Developers",
|
|
19
|
+
"License :: OSI Approved :: Apache Software License",
|
|
20
|
+
"Programming Language :: Python :: 3",
|
|
21
|
+
"Programming Language :: Python :: 3.10",
|
|
22
|
+
"Programming Language :: Python :: 3.11",
|
|
23
|
+
"Programming Language :: Python :: 3.12",
|
|
24
|
+
"Programming Language :: Python :: 3.13",
|
|
25
|
+
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
|
26
|
+
]
|
|
27
|
+
|
|
28
|
+
dependencies = [
|
|
29
|
+
"haystack-ai>=2.19.0",
|
|
30
|
+
"httpx>=0.27.0",
|
|
31
|
+
]
|
|
32
|
+
|
|
33
|
+
[project.optional-dependencies]
|
|
34
|
+
dev = [
|
|
35
|
+
"pytest>=7.0",
|
|
36
|
+
"pytest-asyncio",
|
|
37
|
+
"pytest-cov",
|
|
38
|
+
"ruff>=0.4.0",
|
|
39
|
+
"mypy>=1.10",
|
|
40
|
+
]
|
|
41
|
+
|
|
42
|
+
[project.urls]
|
|
43
|
+
Homepage = "https://whichmodel.dev"
|
|
44
|
+
Repository = "https://github.com/Which-Model/whichmodel-mcp"
|
|
45
|
+
Documentation = "https://github.com/Which-Model/whichmodel-mcp/tree/main/integrations/haystack-whichmodel"
|
|
46
|
+
|
|
47
|
+
[tool.hatch.version]
|
|
48
|
+
source = "vcs"
|
|
49
|
+
tag-pattern = "integrations/haystack-whichmodel-v(?P<version>.*)"
|
|
50
|
+
fallback-version = "0.1.0"
|
|
51
|
+
|
|
52
|
+
[tool.hatch.build.targets.sdist]
|
|
53
|
+
include = ["src/haystack_integrations"]
|
|
54
|
+
|
|
55
|
+
[tool.hatch.build.targets.wheel]
|
|
56
|
+
packages = ["src/haystack_integrations"]
|
|
57
|
+
|
|
58
|
+
[tool.ruff]
|
|
59
|
+
line-length = 120
|
|
60
|
+
|
|
61
|
+
[tool.ruff.lint]
|
|
62
|
+
select = ["E", "F", "W", "I", "UP"]
|
|
63
|
+
|
|
64
|
+
[tool.mypy]
|
|
65
|
+
strict = true
|
|
66
|
+
warn_return_any = true
|
|
67
|
+
|
|
68
|
+
[tool.pytest.ini_options]
|
|
69
|
+
testpaths = ["tests"]
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
@@ -0,0 +1,295 @@
|
|
|
1
|
+
"""WhichModel router component for Haystack pipelines.
|
|
2
|
+
|
|
3
|
+
Calls the WhichModel MCP server to get cost-optimised model recommendations,
|
|
4
|
+
enabling dynamic model selection in Haystack pipelines.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations

import itertools
import json
import logging
from typing import Any, Literal, get_args

import httpx
from haystack import component, default_from_dict, default_to_dict
|
|
15
|
+
|
|
16
|
+
logger = logging.getLogger(__name__)
|
|
17
|
+
|
|
18
|
+
TASK_TYPES = [
|
|
19
|
+
"chat",
|
|
20
|
+
"code_generation",
|
|
21
|
+
"code_review",
|
|
22
|
+
"summarisation",
|
|
23
|
+
"translation",
|
|
24
|
+
"data_extraction",
|
|
25
|
+
"tool_calling",
|
|
26
|
+
"creative_writing",
|
|
27
|
+
"research",
|
|
28
|
+
"classification",
|
|
29
|
+
"embedding",
|
|
30
|
+
"vision",
|
|
31
|
+
"reasoning",
|
|
32
|
+
]
|
|
33
|
+
|
|
34
|
+
TaskType = Literal[
|
|
35
|
+
"chat",
|
|
36
|
+
"code_generation",
|
|
37
|
+
"code_review",
|
|
38
|
+
"summarisation",
|
|
39
|
+
"translation",
|
|
40
|
+
"data_extraction",
|
|
41
|
+
"tool_calling",
|
|
42
|
+
"creative_writing",
|
|
43
|
+
"research",
|
|
44
|
+
"classification",
|
|
45
|
+
"embedding",
|
|
46
|
+
"vision",
|
|
47
|
+
"reasoning",
|
|
48
|
+
]
|
|
49
|
+
|
|
50
|
+
Complexity = Literal["low", "medium", "high"]
|
|
51
|
+
|
|
52
|
+
_MCP_ENDPOINT = "https://whichmodel.dev/mcp"
|
|
53
|
+
_REQUEST_ID_COUNTER = 0
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def _next_request_id() -> int:
|
|
57
|
+
global _REQUEST_ID_COUNTER
|
|
58
|
+
_REQUEST_ID_COUNTER += 1
|
|
59
|
+
return _REQUEST_ID_COUNTER
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
@component
|
|
63
|
+
class WhichModelRouter:
|
|
64
|
+
"""Select the best LLM for a task using WhichModel's cost-aware recommendation engine.
|
|
65
|
+
|
|
66
|
+
This component calls the WhichModel MCP server to get model recommendations
|
|
67
|
+
based on task type, complexity, token estimates, budget, and capability requirements.
|
|
68
|
+
It returns the recommended model ID and full recommendation details that can be used
|
|
69
|
+
to dynamically route to the best model in a Haystack pipeline.
|
|
70
|
+
|
|
71
|
+
Usage:
|
|
72
|
+
|
|
73
|
+
```python
|
|
74
|
+
from haystack_integrations.components.routers.whichmodel import WhichModelRouter
|
|
75
|
+
|
|
76
|
+
router = WhichModelRouter()
|
|
77
|
+
result = router.run(task_type="code_generation", complexity="high")
|
|
78
|
+
print(result["model_id"]) # e.g. "anthropic/claude-sonnet-4"
|
|
79
|
+
print(result["recommendation"]) # full recommendation dict
|
|
80
|
+
```
|
|
81
|
+
"""
|
|
82
|
+
|
|
83
|
+
def __init__(
|
|
84
|
+
self,
|
|
85
|
+
mcp_endpoint: str = _MCP_ENDPOINT,
|
|
86
|
+
timeout: float = 30.0,
|
|
87
|
+
default_task_type: TaskType | None = None,
|
|
88
|
+
default_complexity: Complexity = "medium",
|
|
89
|
+
) -> None:
|
|
90
|
+
"""Initialize the WhichModel router.
|
|
91
|
+
|
|
92
|
+
Args:
|
|
93
|
+
mcp_endpoint: URL of the WhichModel MCP server.
|
|
94
|
+
timeout: HTTP request timeout in seconds.
|
|
95
|
+
default_task_type: Default task type when not provided at runtime.
|
|
96
|
+
default_complexity: Default complexity level.
|
|
97
|
+
"""
|
|
98
|
+
self.mcp_endpoint = mcp_endpoint
|
|
99
|
+
self.timeout = timeout
|
|
100
|
+
self.default_task_type = default_task_type
|
|
101
|
+
self.default_complexity = default_complexity
|
|
102
|
+
self._client: httpx.Client | None = None
|
|
103
|
+
self._session_id: str | None = None
|
|
104
|
+
|
|
105
|
+
def warm_up(self) -> None:
|
|
106
|
+
"""Initialize the HTTP client and open an MCP session."""
|
|
107
|
+
if self._client is None:
|
|
108
|
+
self._client = httpx.Client(timeout=self.timeout)
|
|
109
|
+
self._initialize_session()
|
|
110
|
+
|
|
111
|
+
def _initialize_session(self) -> None:
|
|
112
|
+
"""Send MCP initialize request to establish a session."""
|
|
113
|
+
assert self._client is not None
|
|
114
|
+
payload = {
|
|
115
|
+
"jsonrpc": "2.0",
|
|
116
|
+
"id": _next_request_id(),
|
|
117
|
+
"method": "initialize",
|
|
118
|
+
"params": {
|
|
119
|
+
"protocolVersion": "2025-03-26",
|
|
120
|
+
"capabilities": {},
|
|
121
|
+
"clientInfo": {"name": "haystack-whichmodel", "version": "0.1.0"},
|
|
122
|
+
},
|
|
123
|
+
}
|
|
124
|
+
try:
|
|
125
|
+
resp = self._client.post(
|
|
126
|
+
self.mcp_endpoint,
|
|
127
|
+
json=payload,
|
|
128
|
+
headers={"Content-Type": "application/json", "Accept": "application/json, text/event-stream"},
|
|
129
|
+
)
|
|
130
|
+
resp.raise_for_status()
|
|
131
|
+
|
|
132
|
+
# Capture session ID from Mcp-Session-Id header
|
|
133
|
+
self._session_id = resp.headers.get("mcp-session-id")
|
|
134
|
+
|
|
135
|
+
# Parse response (may be SSE or JSON)
|
|
136
|
+
result = self._parse_response(resp)
|
|
137
|
+
logger.debug("MCP session initialized: %s", result)
|
|
138
|
+
except httpx.HTTPError as e:
|
|
139
|
+
logger.warning("Failed to initialize MCP session: %s", e)
|
|
140
|
+
|
|
141
|
+
def _parse_response(self, resp: httpx.Response) -> dict[str, Any]:
|
|
142
|
+
"""Parse an MCP response, handling both JSON and SSE formats."""
|
|
143
|
+
content_type = resp.headers.get("content-type", "")
|
|
144
|
+
|
|
145
|
+
if "text/event-stream" in content_type:
|
|
146
|
+
# Parse SSE: look for data lines with JSON-RPC response
|
|
147
|
+
for line in resp.text.splitlines():
|
|
148
|
+
if line.startswith("data: "):
|
|
149
|
+
data = line[6:].strip()
|
|
150
|
+
if data:
|
|
151
|
+
parsed = json.loads(data)
|
|
152
|
+
if "result" in parsed or "error" in parsed:
|
|
153
|
+
return parsed
|
|
154
|
+
return {}
|
|
155
|
+
else:
|
|
156
|
+
return resp.json()
|
|
157
|
+
|
|
158
|
+
def _call_tool(self, tool_name: str, arguments: dict[str, Any]) -> dict[str, Any]:
|
|
159
|
+
"""Call an MCP tool and return the result."""
|
|
160
|
+
if self._client is None:
|
|
161
|
+
self.warm_up()
|
|
162
|
+
assert self._client is not None
|
|
163
|
+
|
|
164
|
+
payload = {
|
|
165
|
+
"jsonrpc": "2.0",
|
|
166
|
+
"id": _next_request_id(),
|
|
167
|
+
"method": "tools/call",
|
|
168
|
+
"params": {
|
|
169
|
+
"name": tool_name,
|
|
170
|
+
"arguments": arguments,
|
|
171
|
+
},
|
|
172
|
+
}
|
|
173
|
+
|
|
174
|
+
headers: dict[str, str] = {
|
|
175
|
+
"Content-Type": "application/json",
|
|
176
|
+
"Accept": "application/json, text/event-stream",
|
|
177
|
+
}
|
|
178
|
+
if self._session_id:
|
|
179
|
+
headers["Mcp-Session-Id"] = self._session_id
|
|
180
|
+
|
|
181
|
+
resp = self._client.post(self.mcp_endpoint, json=payload, headers=headers)
|
|
182
|
+
resp.raise_for_status()
|
|
183
|
+
|
|
184
|
+
result = self._parse_response(resp)
|
|
185
|
+
|
|
186
|
+
if "error" in result:
|
|
187
|
+
raise RuntimeError(f"MCP tool call failed: {result['error']}")
|
|
188
|
+
|
|
189
|
+
return result.get("result", {})
|
|
190
|
+
|
|
191
|
+
@component.output_types(
|
|
192
|
+
model_id=str,
|
|
193
|
+
provider=str,
|
|
194
|
+
recommendation=dict,
|
|
195
|
+
alternative=dict,
|
|
196
|
+
budget_model=dict,
|
|
197
|
+
confidence=str,
|
|
198
|
+
data_freshness=str,
|
|
199
|
+
)
|
|
200
|
+
def run(
|
|
201
|
+
self,
|
|
202
|
+
task_type: str | None = None,
|
|
203
|
+
complexity: str | None = None,
|
|
204
|
+
estimated_input_tokens: int | None = None,
|
|
205
|
+
estimated_output_tokens: int | None = None,
|
|
206
|
+
budget_per_call: float | None = None,
|
|
207
|
+
requirements: dict[str, Any] | None = None,
|
|
208
|
+
) -> dict[str, Any]:
|
|
209
|
+
"""Get a cost-optimised model recommendation from WhichModel.
|
|
210
|
+
|
|
211
|
+
Args:
|
|
212
|
+
task_type: The type of task (e.g. "code_generation", "chat", "reasoning").
|
|
213
|
+
complexity: Task complexity: "low", "medium", or "high".
|
|
214
|
+
estimated_input_tokens: Expected input size in tokens.
|
|
215
|
+
estimated_output_tokens: Expected output size in tokens.
|
|
216
|
+
budget_per_call: Maximum spend in USD for a single call.
|
|
217
|
+
requirements: Capability requirements dict with keys like
|
|
218
|
+
``tool_calling``, ``json_output``, ``streaming``,
|
|
219
|
+
``context_window_min``, ``providers_include``, ``providers_exclude``.
|
|
220
|
+
|
|
221
|
+
Returns:
|
|
222
|
+
Dictionary with keys:
|
|
223
|
+
- ``model_id``: Recommended model identifier (e.g. "anthropic/claude-sonnet-4").
|
|
224
|
+
- ``provider``: Model provider name.
|
|
225
|
+
- ``recommendation``: Full recommendation details dict.
|
|
226
|
+
- ``alternative``: Alternative model recommendation (empty dict if none).
|
|
227
|
+
- ``budget_model``: Cheapest viable model (empty dict if none).
|
|
228
|
+
- ``confidence``: Recommendation confidence ("high", "medium", "low").
|
|
229
|
+
- ``data_freshness``: When pricing data was last updated.
|
|
230
|
+
"""
|
|
231
|
+
effective_task_type = task_type or self.default_task_type
|
|
232
|
+
if not effective_task_type:
|
|
233
|
+
raise ValueError(
|
|
234
|
+
"task_type must be provided either at init (default_task_type) or at runtime. "
|
|
235
|
+
f"Valid types: {', '.join(TASK_TYPES)}"
|
|
236
|
+
)
|
|
237
|
+
|
|
238
|
+
effective_complexity = complexity or self.default_complexity
|
|
239
|
+
|
|
240
|
+
# Build MCP tool arguments
|
|
241
|
+
arguments: dict[str, Any] = {
|
|
242
|
+
"task_type": effective_task_type,
|
|
243
|
+
"complexity": effective_complexity,
|
|
244
|
+
}
|
|
245
|
+
if estimated_input_tokens is not None:
|
|
246
|
+
arguments["estimated_input_tokens"] = estimated_input_tokens
|
|
247
|
+
if estimated_output_tokens is not None:
|
|
248
|
+
arguments["estimated_output_tokens"] = estimated_output_tokens
|
|
249
|
+
if budget_per_call is not None:
|
|
250
|
+
arguments["budget_per_call"] = budget_per_call
|
|
251
|
+
if requirements is not None:
|
|
252
|
+
arguments["requirements"] = requirements
|
|
253
|
+
|
|
254
|
+
result = self._call_tool("recommend_model", arguments)
|
|
255
|
+
|
|
256
|
+
# Extract the text content from MCP tool result
|
|
257
|
+
content = result.get("content", [])
|
|
258
|
+
if not content:
|
|
259
|
+
raise RuntimeError("Empty response from WhichModel MCP server")
|
|
260
|
+
|
|
261
|
+
text_content = next((c["text"] for c in content if c.get("type") == "text"), None)
|
|
262
|
+
if text_content is None:
|
|
263
|
+
raise RuntimeError("No text content in WhichModel MCP response")
|
|
264
|
+
|
|
265
|
+
data = json.loads(text_content)
|
|
266
|
+
|
|
267
|
+
# Handle error responses
|
|
268
|
+
if "error" in data:
|
|
269
|
+
raise RuntimeError(f"WhichModel error: {data['error']}")
|
|
270
|
+
|
|
271
|
+
recommended = data.get("recommended", {})
|
|
272
|
+
return {
|
|
273
|
+
"model_id": recommended.get("model_id", ""),
|
|
274
|
+
"provider": recommended.get("provider", ""),
|
|
275
|
+
"recommendation": recommended,
|
|
276
|
+
"alternative": data.get("alternative") or {},
|
|
277
|
+
"budget_model": data.get("budget_model") or {},
|
|
278
|
+
"confidence": data.get("confidence", "low"),
|
|
279
|
+
"data_freshness": data.get("data_freshness", ""),
|
|
280
|
+
}
|
|
281
|
+
|
|
282
|
+
def to_dict(self) -> dict[str, Any]:
|
|
283
|
+
"""Serialize this component to a dictionary."""
|
|
284
|
+
return default_to_dict(
|
|
285
|
+
self,
|
|
286
|
+
mcp_endpoint=self.mcp_endpoint,
|
|
287
|
+
timeout=self.timeout,
|
|
288
|
+
default_task_type=self.default_task_type,
|
|
289
|
+
default_complexity=self.default_complexity,
|
|
290
|
+
)
|
|
291
|
+
|
|
292
|
+
@classmethod
|
|
293
|
+
def from_dict(cls, data: dict[str, Any]) -> WhichModelRouter:
|
|
294
|
+
"""Deserialize a component from a dictionary."""
|
|
295
|
+
return default_from_dict(cls, data)
|