pyagent-router 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pyagent_router-0.1.0/.gitignore +18 -0
- pyagent_router-0.1.0/PKG-INFO +55 -0
- pyagent_router-0.1.0/README.md +30 -0
- pyagent_router-0.1.0/pyproject.toml +30 -0
- pyagent_router-0.1.0/src/pyagent_router/__init__.py +9 -0
- pyagent_router-0.1.0/src/pyagent_router/estimator.py +134 -0
- pyagent_router-0.1.0/src/pyagent_router/middleware.py +115 -0
- pyagent_router-0.1.0/src/pyagent_router/py.typed +0 -0
- pyagent_router-0.1.0/src/pyagent_router/scorer.py +121 -0
- pyagent_router-0.1.0/src/pyagent_router/selector.py +156 -0
- pyagent_router-0.1.0/tests/__init__.py +0 -0
- pyagent_router-0.1.0/tests/test_router.py +81 -0
|
@@ -0,0 +1,55 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: pyagent-router
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: Difficulty-aware routing and model selection for multi-agent LLM workflows
|
|
5
|
+
License: MIT
|
|
6
|
+
Keywords: LLM,agents,cost-optimization,model-selection,routing
|
|
7
|
+
Classifier: Development Status :: 3 - Alpha
|
|
8
|
+
Classifier: Intended Audience :: Developers
|
|
9
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
10
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
11
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
12
|
+
Classifier: Programming Language :: Python :: 3.13
|
|
13
|
+
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
14
|
+
Classifier: Typing :: Typed
|
|
15
|
+
Requires-Python: >=3.11
|
|
16
|
+
Requires-Dist: pyagent-patterns>=0.1.0
|
|
17
|
+
Provides-Extra: dev
|
|
18
|
+
Requires-Dist: mypy>=1.10; extra == 'dev'
|
|
19
|
+
Requires-Dist: pytest-asyncio>=0.23; extra == 'dev'
|
|
20
|
+
Requires-Dist: pytest>=8.0; extra == 'dev'
|
|
21
|
+
Requires-Dist: ruff>=0.5; extra == 'dev'
|
|
22
|
+
Provides-Extra: tiktoken
|
|
23
|
+
Requires-Dist: tiktoken>=0.7; extra == 'tiktoken'
|
|
24
|
+
Description-Content-Type: text/markdown
|
|
25
|
+
|
|
26
|
+
# pyagent-router
|
|
27
|
+
|
|
28
|
+
**Difficulty-aware routing and model selection** for multi-agent LLM workflows. Route easy tasks to cheap models, hard tasks to expensive ones.
|
|
29
|
+
|
|
30
|
+
## Install
|
|
31
|
+
|
|
32
|
+
```bash
|
|
33
|
+
pip install pyagent-router
|
|
34
|
+
```
|
|
35
|
+
|
|
36
|
+
## Components
|
|
37
|
+
|
|
38
|
+
- **DifficultyScorer** — Score task difficulty 1-10 based on heuristics
|
|
39
|
+
- **CostEstimator** — Estimate LLM call costs with built-in model pricing
|
|
40
|
+
- **ModelSelector** — Auto-select the cheapest viable model
|
|
41
|
+
- **RouterMiddleware** — Inject routing into agent calls
|
|
42
|
+
|
|
43
|
+
## Quick Example
|
|
44
|
+
|
|
45
|
+
```python
|
|
46
|
+
from pyagent_router import ModelSelector
|
|
47
|
+
|
|
48
|
+
result = ModelSelector().select("What is 2+2?")
|
|
49
|
+
print(f"{result.model}: ${result.cost_estimate.total_cost:.6f}")
|
|
50
|
+
# gpt-4.1-nano: $0.000001 (vs. about $0.0000175 with gpt-4o for the same prompt)
|
|
51
|
+
```
|
|
52
|
+
|
|
53
|
+
## Typical Savings: 40-60%
|
|
54
|
+
|
|
55
|
+
For workloads where 70% of queries are easy, routing to cheap models saves 40-60% vs always using the most expensive model.
|
|
@@ -0,0 +1,30 @@
|
|
|
1
|
+
# pyagent-router
|
|
2
|
+
|
|
3
|
+
**Difficulty-aware routing and model selection** for multi-agent LLM workflows. Route easy tasks to cheap models, hard tasks to expensive ones.
|
|
4
|
+
|
|
5
|
+
## Install
|
|
6
|
+
|
|
7
|
+
```bash
|
|
8
|
+
pip install pyagent-router
|
|
9
|
+
```
|
|
10
|
+
|
|
11
|
+
## Components
|
|
12
|
+
|
|
13
|
+
- **DifficultyScorer** — Score task difficulty 1-10 based on heuristics
|
|
14
|
+
- **CostEstimator** — Estimate LLM call costs with built-in model pricing
|
|
15
|
+
- **ModelSelector** — Auto-select the cheapest viable model
|
|
16
|
+
- **RouterMiddleware** — Inject routing into agent calls
|
|
17
|
+
|
|
18
|
+
## Quick Example
|
|
19
|
+
|
|
20
|
+
```python
|
|
21
|
+
from pyagent_router import ModelSelector
|
|
22
|
+
|
|
23
|
+
result = ModelSelector().select("What is 2+2?")
|
|
24
|
+
print(f"{result.model}: ${result.cost_estimate.total_cost:.6f}")
|
|
25
|
+
# gpt-4.1-nano: $0.000001 (vs. about $0.0000175 with gpt-4o for the same prompt)
|
|
26
|
+
```
|
|
27
|
+
|
|
28
|
+
## Typical Savings: 40-60%
|
|
29
|
+
|
|
30
|
+
For workloads where 70% of queries are easy, routing to cheap models saves 40-60% vs always using the most expensive model.
|
|
@@ -0,0 +1,30 @@
|
|
|
1
|
+
[build-system]
|
|
2
|
+
requires = ["hatchling"]
|
|
3
|
+
build-backend = "hatchling.build"
|
|
4
|
+
|
|
5
|
+
[project]
|
|
6
|
+
name = "pyagent-router"
|
|
7
|
+
version = "0.1.0"
|
|
8
|
+
description = "Difficulty-aware routing and model selection for multi-agent LLM workflows"
|
|
9
|
+
readme = "README.md"
|
|
10
|
+
requires-python = ">=3.11"
|
|
11
|
+
license = {text = "MIT"}
|
|
12
|
+
keywords = ["agents", "routing", "LLM", "cost-optimization", "model-selection"]
|
|
13
|
+
classifiers = [
|
|
14
|
+
"Development Status :: 3 - Alpha",
|
|
15
|
+
"Intended Audience :: Developers",
|
|
16
|
+
"License :: OSI Approved :: MIT License",
|
|
17
|
+
"Programming Language :: Python :: 3.11",
|
|
18
|
+
"Programming Language :: Python :: 3.12",
|
|
19
|
+
"Programming Language :: Python :: 3.13",
|
|
20
|
+
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
|
21
|
+
"Typing :: Typed",
|
|
22
|
+
]
|
|
23
|
+
dependencies = ["pyagent-patterns>=0.1.0"]
|
|
24
|
+
|
|
25
|
+
[project.optional-dependencies]
|
|
26
|
+
tiktoken = ["tiktoken>=0.7"]
|
|
27
|
+
dev = ["pytest>=8.0", "pytest-asyncio>=0.23", "ruff>=0.5", "mypy>=1.10"]
|
|
28
|
+
|
|
29
|
+
[tool.hatch.build.targets.wheel]
|
|
30
|
+
packages = ["src/pyagent_router"]
|
|
@@ -0,0 +1,9 @@
|
|
|
1
|
+
"""PyAgent Router — difficulty-aware routing and model selection for multi-agent LLM workflows."""
|
|
2
|
+
|
|
3
|
+
from pyagent_router.estimator import CostEstimator
|
|
4
|
+
from pyagent_router.middleware import RouterMiddleware
|
|
5
|
+
from pyagent_router.scorer import DifficultyScorer
|
|
6
|
+
from pyagent_router.selector import ModelSelector
|
|
7
|
+
|
|
8
|
+
# Public API re-exported at package level (each name is imported above).
__all__ = [
    "DifficultyScorer",
    "CostEstimator",
    "ModelSelector",
    "RouterMiddleware",
]

# Must match ``version`` in pyproject.toml.
__version__ = "0.1.0"
|
|
@@ -0,0 +1,134 @@
|
|
|
1
|
+
"""CostEstimator: estimate LLM call cost based on token count and model pricing.
|
|
2
|
+
|
|
3
|
+
Maintains a registry of model pricing (per 1M input/output tokens)
|
|
4
|
+
and estimates cost for a given task and model.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
from dataclasses import dataclass
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
@dataclass(frozen=True)
class ModelPricing:
    """Price of one model in USD per one million tokens.

    Attributes:
        input_per_million: USD per 1M input (prompt) tokens.
        output_per_million: USD per 1M output (completion) tokens.
    """

    input_per_million: float
    output_per_million: float


@dataclass(frozen=True)
class CostEstimate:
    """Estimated cost for a single LLM call.

    Attributes:
        model: Model name.
        input_tokens: Estimated input tokens.
        output_tokens: Estimated output tokens.
        input_cost: Cost for input tokens in USD.
        output_cost: Cost for output tokens in USD.
        total_cost: Total estimated cost in USD (computed property).
    """

    model: str
    input_tokens: int
    output_tokens: int
    input_cost: float
    output_cost: float

    @property
    def total_cost(self) -> float:
        """Return ``input_cost + output_cost`` in USD."""
        return self.input_cost + self.output_cost


# Pricing as of mid-2025 (USD per 1M tokens).
# NOTE(review): provider prices change often; re-check these against the
# providers' published pricing before relying on the estimates.
DEFAULT_PRICING: dict[str, ModelPricing] = {
    "gpt-4o": ModelPricing(2.50, 10.00),
    "gpt-4o-mini": ModelPricing(0.15, 0.60),
    "gpt-4.1": ModelPricing(2.00, 8.00),
    "gpt-4.1-mini": ModelPricing(0.40, 1.60),
    "gpt-4.1-nano": ModelPricing(0.10, 0.40),
    "o3": ModelPricing(10.00, 40.00),
    "o3-mini": ModelPricing(1.10, 4.40),
    "o4-mini": ModelPricing(1.10, 4.40),
    "claude-sonnet-4": ModelPricing(3.00, 15.00),
    "claude-haiku-3.5": ModelPricing(0.80, 4.00),
    "gemini-2.5-flash": ModelPricing(0.15, 0.60),
    "gemini-2.5-pro": ModelPricing(1.25, 10.00),
}


class CostEstimator:
    """Estimate LLM call costs based on a model pricing registry.

    Args:
        pricing: Optional custom pricing dict (model name -> ModelPricing).
            Defaults to the built-in ``DEFAULT_PRICING`` table. The dict is
            copied, so later changes to the caller's dict do not affect this
            estimator. An explicitly empty dict yields an estimator with no
            models.
        default_output_ratio: Estimated output/input token ratio used when the
            output length is unknown. Must be non-negative.

    Raises:
        ValueError: If ``default_output_ratio`` is negative.
    """

    def __init__(
        self,
        pricing: dict[str, ModelPricing] | None = None,
        default_output_ratio: float = 0.5,
    ) -> None:
        if default_output_ratio < 0:
            raise ValueError(
                f"default_output_ratio must be non-negative, got {default_output_ratio}"
            )
        # ``is None`` rather than ``or`` so an explicitly empty registry is
        # honoured; always copy so mutating the caller's dict (or the module
        # default) afterwards cannot change this estimator's prices.
        source = DEFAULT_PRICING if pricing is None else pricing
        self._pricing: dict[str, ModelPricing] = dict(source)
        self._output_ratio = default_output_ratio

    def estimate(
        self,
        model: str,
        input_tokens: int,
        output_tokens: int | None = None,
    ) -> CostEstimate:
        """Estimate cost for a single LLM call.

        Args:
            model: Model name (must be in pricing registry).
            input_tokens: Number of input tokens (non-negative).
            output_tokens: Number of output tokens. If None, estimated as
                ``int(input_tokens * default_output_ratio)``.

        Returns:
            CostEstimate with breakdown.

        Raises:
            KeyError: If model not found in pricing registry.
            ValueError: If a token count is negative.
        """
        if model not in self._pricing:
            raise KeyError(
                f"Model '{model}' not in pricing registry. "
                f"Available: {', '.join(sorted(self._pricing.keys()))}"
            )
        if input_tokens < 0:
            raise ValueError(f"input_tokens must be non-negative, got {input_tokens}")
        if output_tokens is not None and output_tokens < 0:
            raise ValueError(f"output_tokens must be non-negative, got {output_tokens}")

        pricing = self._pricing[model]
        est_output = (
            output_tokens if output_tokens is not None else int(input_tokens * self._output_ratio)
        )

        return CostEstimate(
            model=model,
            input_tokens=input_tokens,
            output_tokens=est_output,
            input_cost=input_tokens * pricing.input_per_million / 1_000_000,
            output_cost=est_output * pricing.output_per_million / 1_000_000,
        )

    def estimate_from_text(self, model: str, text: str) -> CostEstimate:
        """Estimate cost from raw text (approximates 4 chars per token).

        Text shorter than 4 characters is counted as 0 input tokens.
        """
        input_tokens = len(text) // 4
        return self.estimate(model, input_tokens)

    def compare(self, text: str, models: list[str] | None = None) -> list[CostEstimate]:
        """Compare costs across multiple models for the same input.

        Args:
            text: The input text.
            models: Models to compare. ``None`` means all registered models;
                an empty list yields an empty result. Names missing from the
                pricing registry are skipped silently.

        Returns:
            List of CostEstimates sorted by total_cost ascending.
        """
        # ``is None`` so that ``models=[]`` means "no models", not "all models".
        target_models = list(self._pricing) if models is None else models
        estimates = [self.estimate_from_text(m, text) for m in target_models if m in self._pricing]
        return sorted(estimates, key=lambda e: e.total_cost)

    @property
    def available_models(self) -> list[str]:
        """Registered model names, sorted alphabetically."""
        return sorted(self._pricing.keys())
|
|
@@ -0,0 +1,115 @@
|
|
|
1
|
+
"""RouterMiddleware: inject routing into any pattern's agent calls.
|
|
2
|
+
|
|
3
|
+
Wraps an Agent so that every call is automatically routed to the
|
|
4
|
+
optimal model based on task difficulty and cost.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
from pyagent_patterns.base import Agent, LLMCallable, Message
|
|
10
|
+
from pyagent_router.selector import Capability, ModelSelector, SelectionResult
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class RoutedAgent(Agent):
    """An agent wrapper that routes each call through ModelSelector.

    On every call the selector picks a model. The call is dispatched to the
    selected model if it is present in ``model_registry``. Otherwise it goes
    to the first of the selector's ``alternatives`` (in the order the selector
    lists them) that is registered. If none of them is registered, the wrapped
    agent's own LLM is used.

    The routing decision is appended to ``routing_log`` and recorded in the
    reply's metadata for tracing.

    Args:
        agent: The original agent.
        selector: ModelSelector to use for routing decisions.
        model_registry: Mapping of model names to LLM callables.
        required_capability: Optional capability filter for model selection.
    """

    def __init__(
        self,
        agent: Agent,
        selector: ModelSelector,
        model_registry: dict[str, LLMCallable],
        required_capability: Capability | None = None,
    ) -> None:
        super().__init__(
            name=agent.name,
            llm=agent.llm,
            system_prompt=agent.system_prompt,
            description=agent.description,
        )
        self._original_agent = agent
        self._selector = selector
        self._model_registry = model_registry
        self._required_capability = required_capability
        # One SelectionResult per call to ``run``, in call order.
        self.routing_log: list[SelectionResult] = []

    def _pick_llm(self, selection: SelectionResult) -> tuple[str | None, LLMCallable]:
        """Return ``(registry_name, llm)`` to dispatch to for *selection*.

        ``registry_name`` is None when neither the selected model nor any
        alternative is registered and the wrapped agent's own LLM is used.
        """
        for name in (selection.model, *selection.alternatives):
            if name in self._model_registry:
                return name, self._model_registry[name]
        return None, self._original_agent.llm

    async def run(self, messages: list[Message]) -> Message:
        """Route the call to the best registered model, then execute it.

        Reply metadata keys added:
            routed_model: Registry name actually dispatched to, or None if the
                wrapped agent's own LLM was used.
            selected_model: The selector's first choice (may differ from
                ``routed_model`` when it is not registered).
            difficulty: Difficulty score of the task text.
            estimated_cost: Selector's cost estimate for ``selected_model``.
            reason: Selector's explanation.
        """
        # Extract task text from messages for difficulty scoring
        task_text = " ".join(m.content for m in messages if m.content)

        selection = self._selector.select(task_text, self._required_capability)
        self.routing_log.append(selection)

        dispatched_model, llm = self._pick_llm(selection)

        # Create a temporary agent with the routed LLM
        routed = Agent(
            name=self._original_agent.name,
            llm=llm,
            system_prompt=self._original_agent.system_prompt,
            description=self._original_agent.description,
        )
        reply = await routed.run(messages)

        # Attach routing metadata to the message
        return Message(
            role=reply.role,
            content=reply.content,
            name=reply.name,
            metadata={
                **reply.metadata,
                "routed_model": dispatched_model,
                "selected_model": selection.model,
                "difficulty": selection.difficulty.score,
                "estimated_cost": selection.cost_estimate.total_cost,
                "reason": selection.reason,
            },
        )
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
class RouterMiddleware:
    """Wrap agents so that each of their calls is routed to a model automatically.

    Every agent wrapped by one middleware instance shares the same model
    registry, ``ModelSelector`` and default capability filter.

    Usage:
        middleware = RouterMiddleware(model_registry={"gpt-4o": my_gpt4o, "gpt-4o-mini": my_mini})
        routed_agent = middleware.wrap(my_agent)
        # routed_agent now auto-selects model per call

    Args:
        model_registry: Mapping of model names to LLM callables.
        selector: Optional ModelSelector; a default ``ModelSelector()`` is
            built when none (or a falsy value) is given.
        required_capability: Optional capability filter applied to every
            wrapped agent.
    """

    def __init__(
        self,
        model_registry: dict[str, LLMCallable],
        selector: ModelSelector | None = None,
        required_capability: Capability | None = None,
    ) -> None:
        if not selector:
            selector = ModelSelector()
        self._registry = model_registry
        self._selector = selector
        self._capability = required_capability

    def wrap(self, agent: Agent) -> RoutedAgent:
        """Return *agent* wrapped in a ``RoutedAgent`` using this middleware's settings."""
        routed = RoutedAgent(
            agent=agent,
            selector=self._selector,
            model_registry=self._registry,
            required_capability=self._capability,
        )
        return routed

    def wrap_all(self, agents: list[Agent]) -> list[RoutedAgent]:
        """Wrap every agent in *agents*, preserving their order."""
        return list(map(self.wrap, agents))
|
|
File without changes
|
|
@@ -0,0 +1,121 @@
|
|
|
1
|
+
"""DifficultyScorer: estimate task difficulty to inform model routing.
|
|
2
|
+
|
|
3
|
+
Uses heuristics (token count, keyword complexity, question type) and
|
|
4
|
+
optionally caller-supplied custom signal functions (e.g. a wrapped LLM classifier) to score difficulty 1-10.
|
|
5
|
+
|
|
6
|
+
Based on: arxiv:2509.11079 "Difficulty-Aware Agent Orchestration"
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from __future__ import annotations
|
|
10
|
+
|
|
11
|
+
import re
|
|
12
|
+
from dataclasses import dataclass, field
|
|
13
|
+
from typing import Any
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
@dataclass(frozen=True)
class DifficultyScore:
    """Result of difficulty scoring.

    Attributes:
        score: Difficulty score from 1 (trivial) to 10 (extremely hard).
        signals: Dictionary of individual signal scores that contributed.
        category: Human-readable difficulty category.
    """

    score: int
    signals: dict[str, float] = field(default_factory=dict)
    category: str = ""

    @property
    def is_easy(self) -> bool:
        """True for scores 1-3."""
        return self.score <= 3

    @property
    def is_medium(self) -> bool:
        """True for scores 4-6."""
        return 4 <= self.score <= 6

    @property
    def is_hard(self) -> bool:
        """True for scores 7 and above."""
        return self.score >= 7


# Keywords indicating higher complexity
_COMPLEX_KEYWORDS = {
    "analyze", "compare", "contrast", "evaluate", "synthesize",
    "design", "architect", "optimize", "debug", "prove",
    "multi-step", "trade-off", "implications", "comprehensive",
    "algorithm", "mathematical", "formal", "derivation",
}

_SIMPLE_KEYWORDS = {
    "what is", "define", "list", "name", "when was",
    "who is", "how many", "yes or no", "true or false",
    "translate", "convert", "summarize",
}


def _compile_keywords(keywords: set[str]) -> tuple[re.Pattern[str], ...]:
    """Compile one pattern per keyword, anchored at the start of a word."""
    # A leading ``\b`` keeps inflections ("analyzes", "designed") matching
    # while rejecting embedded substrings ("improve" for "prove",
    # "specialist" for "list", "somewhat is" for "what is").
    return tuple(re.compile(r"\b" + re.escape(kw)) for kw in sorted(keywords))


_COMPLEX_PATTERNS = _compile_keywords(_COMPLEX_KEYWORDS)
_SIMPLE_PATTERNS = _compile_keywords(_SIMPLE_KEYWORDS)

# A numbered-list marker such as "1. " at the start of the text or after
# whitespace. Requiring whitespace (or end of text) after the dot and at most
# two digits excludes decimals ("2.5"), versions ("3.11") and years ("2023.").
_NUMBERED_ITEM = re.compile(r"(?:^|(?<=\s))\d{1,2}\.(?=\s|$)")


class DifficultyScorer:
    """Heuristic-based task difficulty scorer.

    Scores tasks on a 1-10 scale using multiple signals:
    - Word count
    - Keyword complexity
    - Question structure (numbered steps, multiple questions)
    - Code/math indicators

    Args:
        custom_signals: Optional dict of custom signal functions.
            Each function takes a task string and returns a float 0-1;
            values outside that range are clamped into it.
    """

    def __init__(
        self,
        custom_signals: dict[str, Any] | None = None,
    ) -> None:
        self._custom_signals = custom_signals or {}

    def score(self, task: str) -> DifficultyScore:
        """Score the difficulty of a task string."""
        signals: dict[str, float] = {}
        task_lower = task.lower()

        # Signal 1: Length complexity (longer = harder, saturating at 200 words)
        word_count = len(task.split())
        signals["length"] = min(word_count / 200, 1.0)

        # Signal 2: Keyword complexity (complex hits raise, simple hits lower)
        complex_hits = sum(1 for pattern in _COMPLEX_PATTERNS if pattern.search(task_lower))
        simple_hits = sum(1 for pattern in _SIMPLE_PATTERNS if pattern.search(task_lower))
        keyword_signal = min(complex_hits / 3, 1.0) - min(simple_hits / 3, 0.5)
        signals["keywords"] = max(0.0, keyword_signal)

        # Signal 3: Multi-part questions. Every numbered step counts; question
        # marks count only beyond the first (a single question is normal).
        numbered_items = len(_NUMBERED_ITEM.findall(task))
        extra_questions = max(task.count("?") - 1, 0)
        signals["multi_part"] = min((numbered_items + extra_questions) / 5, 1.0)

        # Signal 4: Code/math indicators
        code_indicators = sum(
            1 for marker in ["```", "def ", "class ", "function", "import"] if marker in task
        )
        math_indicators = sum(
            1 for marker in ["∑", "∫", "equation", "formula", "proof"] if marker in task_lower
        )
        signals["technical"] = min((code_indicators + math_indicators) / 3, 1.0)

        # Custom signals, clamped to the documented 0-1 range
        for name, fn in self._custom_signals.items():
            signals[name] = min(max(float(fn(task)), 0.0), 1.0)

        # Weighted average → 1-10 scale
        weights = {"length": 0.15, "keywords": 0.35, "multi_part": 0.25, "technical": 0.25}
        weighted_sum = sum(signals.get(k, 0) * w for k, w in weights.items())

        # Add custom signals with equal weight
        if self._custom_signals:
            custom_avg = sum(signals.get(k, 0) for k in self._custom_signals) / len(self._custom_signals)
            weighted_sum = weighted_sum * 0.7 + custom_avg * 0.3

        raw_score = max(1, min(10, int(weighted_sum * 10) + 1))

        category = "easy" if raw_score <= 3 else "medium" if raw_score <= 6 else "hard"

        return DifficultyScore(score=raw_score, signals=signals, category=category)
|
|
@@ -0,0 +1,156 @@
|
|
|
1
|
+
"""ModelSelector: pick the best model based on difficulty, cost, and capability constraints.
|
|
2
|
+
|
|
3
|
+
Combines DifficultyScorer output with CostEstimator to select the
|
|
4
|
+
cheapest model that meets quality requirements.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
from dataclasses import dataclass, field
|
|
10
|
+
from enum import Enum
|
|
11
|
+
|
|
12
|
+
from pyagent_router.estimator import CostEstimate, CostEstimator
|
|
13
|
+
from pyagent_router.scorer import DifficultyScore, DifficultyScorer
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class Capability(str, Enum):
    """Capability tags attached to a ``ModelSpec`` and used to filter candidates.

    Members subclass ``str``, so each one compares equal to its string value
    (``Capability.CODE == "code"``).
    """

    # NOTE: declaration order is the iteration order of ``Capability``.
    CODE = "code"
    MATH = "math"
    REASONING = "reasoning"
    CREATIVE = "creative"
    GENERAL = "general"
    VISION = "vision"
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def _general_only() -> set[Capability]:
    """Default capability set for a spec: only ``Capability.GENERAL``."""
    return {Capability.GENERAL}


@dataclass
class ModelSpec:
    """Routing description of one model: which difficulties and capabilities it covers.

    Attributes:
        name: Model identifier; should match a key in the cost estimator's
            pricing registry.
        min_difficulty: Lowest difficulty score this model should be routed.
        max_difficulty: Highest difficulty score this model should be routed.
        capabilities: Capabilities this model is considered good at
            (defaults to just GENERAL).
        max_context: Maximum context window in tokens.
    """

    name: str
    min_difficulty: int = 1
    max_difficulty: int = 10
    capabilities: set[Capability] = field(default_factory=_general_only)
    max_context: int = 128_000
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
# Built-in routing table, listed in ascending order of ``max_difficulty``.
DEFAULT_MODEL_SPECS: list[ModelSpec] = [
    ModelSpec(
        "gpt-4.1-nano",
        min_difficulty=1,
        max_difficulty=3,
        capabilities={Capability.GENERAL},
        max_context=1_000_000,
    ),
    ModelSpec(
        "gpt-4o-mini",
        min_difficulty=1,
        max_difficulty=5,
        capabilities={Capability.GENERAL, Capability.CODE},
        max_context=128_000,
    ),
    ModelSpec(
        "gpt-4.1-mini",
        min_difficulty=1,
        max_difficulty=6,
        capabilities={Capability.GENERAL, Capability.CODE, Capability.REASONING},
        max_context=1_000_000,
    ),
    ModelSpec(
        "gpt-4o",
        min_difficulty=3,
        max_difficulty=8,
        capabilities={Capability.GENERAL, Capability.CODE, Capability.CREATIVE, Capability.VISION},
        max_context=128_000,
    ),
    ModelSpec(
        "gpt-4.1",
        min_difficulty=4,
        max_difficulty=9,
        capabilities={Capability.GENERAL, Capability.CODE, Capability.REASONING},
        max_context=1_000_000,
    ),
    ModelSpec(
        "claude-sonnet-4",
        min_difficulty=5,
        max_difficulty=10,
        capabilities={Capability.GENERAL, Capability.CODE, Capability.REASONING, Capability.CREATIVE},
        max_context=200_000,
    ),
    ModelSpec(
        "o3-mini",
        min_difficulty=6,
        max_difficulty=10,
        capabilities={Capability.REASONING, Capability.MATH, Capability.CODE},
        max_context=200_000,
    ),
    ModelSpec(
        "o3",
        min_difficulty=8,
        max_difficulty=10,
        capabilities={Capability.REASONING, Capability.MATH, Capability.CODE},
        max_context=200_000,
    ),
]
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
@dataclass(frozen=True)
class SelectionResult:
    """Immutable outcome of a model-selection decision.

    Attributes:
        model: Name of the chosen model.
        difficulty: Difficulty assessment of the task.
        cost_estimate: Estimated cost of running the task on ``model``.
        reason: Human-readable explanation of the choice.
        alternatives: Names of other candidate models that were considered.
    """

    model: str
    difficulty: DifficultyScore
    cost_estimate: CostEstimate
    reason: str
    alternatives: list[str] = field(default_factory=list)
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
class ModelSelector:
    """Select the optimal model based on task difficulty and cost.

    Strategy: find the cheapest model whose difficulty range covers the task
    and whose capabilities include the required capability (if specified).

    If no spec satisfies both constraints, fall back as follows:

    1. Keep the capability constraint, dropping it only if no spec offers
       the capability at all.
    2. Among those specs, consider every one whose ``max_difficulty`` is at
       least the task's score (an over-qualified model is acceptable), and
       pick the cheapest.
    3. If the task is harder than all of them, use the one with the highest
       ``max_difficulty`` (the later spec wins ties).

    Args:
        specs: List of ModelSpec definitions. Defaults to built-in specs (an
            empty list also selects the defaults).
        cost_estimator: CostEstimator instance. Created automatically if None.
        scorer: DifficultyScorer instance. Created automatically if None.
    """

    def __init__(
        self,
        specs: list[ModelSpec] | None = None,
        cost_estimator: CostEstimator | None = None,
        scorer: DifficultyScorer | None = None,
    ) -> None:
        # ``or`` on purpose: selection needs at least one spec, so an empty
        # list also falls back to the defaults.
        self._specs = specs or list(DEFAULT_MODEL_SPECS)
        self._estimator = cost_estimator or CostEstimator()
        self._scorer = scorer or DifficultyScorer()

    def _matching_specs(self, score: int, capability: Capability | None) -> list[ModelSpec]:
        """Specs whose difficulty range covers *score* and that offer *capability*."""
        return [
            spec
            for spec in self._specs
            if spec.min_difficulty <= score <= spec.max_difficulty
            and (capability is None or capability in spec.capabilities)
        ]

    def _fallback_specs(self, score: int, capability: Capability | None) -> list[ModelSpec]:
        """Candidates to use when no spec matches both difficulty and capability."""
        pool = [s for s in self._specs if capability is None or capability in s.capabilities]
        if not pool:
            # Nobody offers the capability: ignore it rather than fail.
            pool = list(self._specs)
        able = [s for s in pool if s.max_difficulty >= score]
        if able:
            return able
        # Task is harder than every spec's range: take the most capable one.
        # ``reversed`` makes the later spec win ties.
        return [max(reversed(pool), key=lambda s: s.max_difficulty)]

    def select(
        self,
        task: str,
        required_capability: Capability | None = None,
    ) -> SelectionResult:
        """Select the best model for a given task.

        Args:
            task: The task text to analyze.
            required_capability: Optional capability filter.

        Returns:
            SelectionResult with the chosen model, its cost estimate, the other
            priced candidates (cheapest first) and a reason string.
        """
        difficulty = self._scorer.score(task)

        candidates = self._matching_specs(difficulty.score, required_capability)
        used_fallback = not candidates
        if used_fallback:
            candidates = self._fallback_specs(difficulty.score, required_capability)

        # Estimate cost for each candidate; specs without pricing are skipped.
        scored: list[tuple[ModelSpec, CostEstimate]] = []
        for spec in candidates:
            try:
                estimate = self._estimator.estimate_from_text(spec.name, task)
            except KeyError:
                continue
            scored.append((spec, estimate))

        if scored:
            # Stable sort: equal-cost candidates keep their spec-list order.
            scored.sort(key=lambda pair: pair[1].total_cost)
            chosen_spec, chosen_cost = scored[0]
            alternatives = [spec.name for spec, _ in scored[1:]]
            detail = f"cheapest candidate at ${chosen_cost.total_cost:.6f}"
        else:
            # Last resort: no candidate could be priced. The first candidate
            # still satisfies the routing constraints; report a zero-cost
            # placeholder estimate (~4 chars/token in, half as many out).
            chosen_spec = candidates[0]
            chosen_cost = CostEstimate(chosen_spec.name, len(task) // 4, len(task) // 8, 0.0, 0.0)
            alternatives = []
            detail = "no pricing available, cost unknown"

        if used_fallback:
            detail += "; fallback: no spec matched difficulty and capability"

        reason = (
            f"Difficulty {difficulty.score}/10 ({difficulty.category}) → "
            f"{chosen_spec.name} ({detail})"
        )

        return SelectionResult(
            model=chosen_spec.name,
            difficulty=difficulty,
            cost_estimate=chosen_cost,
            reason=reason,
            alternatives=alternatives,
        )
|
|
File without changes
|
|
@@ -0,0 +1,81 @@
|
|
|
1
|
+
"""Tests for pyagent-router."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import pytest
|
|
6
|
+
|
|
7
|
+
from pyagent_patterns.base import Agent, MockLLM
|
|
8
|
+
from pyagent_router.estimator import CostEstimator
|
|
9
|
+
from pyagent_router.middleware import RouterMiddleware
|
|
10
|
+
from pyagent_router.scorer import DifficultyScorer
|
|
11
|
+
from pyagent_router.selector import Capability, ModelSelector
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def test_difficulty_scorer_easy():
    """A one-line arithmetic question is scored as easy."""
    score = DifficultyScorer().score("What is 2 + 2?")
    assert score.is_easy
    assert score.category == "easy"
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def test_difficulty_scorer_hard():
    """A multi-step architecture/design prompt scores at least medium."""
    prompt = (
        "Analyze and compare the trade-offs between microservice and monolithic "
        "architectures. Design an optimal system that synthesizes the benefits of both. "
        "1. Evaluate scalability. 2. Evaluate maintainability. 3. Evaluate cost. "
        "4. Propose a hybrid architecture. 5. Prove its optimality."
    )
    outcome = DifficultyScorer().score(prompt)
    assert outcome.score >= 4  # medium (4-6) or hard (7-10)
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def test_cost_estimator_basic():
    """Explicit token counts are priced at gpt-4o-mini's table rates."""
    estimate = CostEstimator().estimate("gpt-4o-mini", input_tokens=1000, output_tokens=500)
    assert estimate.model == "gpt-4o-mini"
    assert estimate.input_tokens == 1000
    assert estimate.output_tokens == 500
    # approx: computed floats need not be bit-identical to a re-ordered formula.
    assert estimate.input_cost == pytest.approx(1000 * 0.15 / 1_000_000)
    assert estimate.output_cost == pytest.approx(500 * 0.60 / 1_000_000)
    assert estimate.total_cost == pytest.approx(estimate.input_cost + estimate.output_cost)
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def test_cost_estimator_compare():
    """compare() returns one estimate per requested model, cheapest first."""
    ranked = CostEstimator().compare("Short text", models=["gpt-4o", "gpt-4o-mini"])
    assert len(ranked) == 2
    cheaper, pricier = ranked
    assert cheaper.total_cost <= pricier.total_cost  # sorted ascending
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def test_model_selector_easy_task():
    """An easy factual question is routed to one of the cheap models."""
    outcome = ModelSelector().select("What is the capital of France?")
    assert outcome.difficulty.is_easy or outcome.difficulty.is_medium
    cheap_models = {"gpt-4.1-nano", "gpt-4o-mini", "gpt-4.1-mini"}
    assert outcome.model in cheap_models
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def test_model_selector_with_capability():
    """Requiring CODE yields a model whose spec advertises CODE."""
    selector = ModelSelector()
    outcome = selector.select(
        "Write a complex algorithm for graph traversal",
        required_capability=Capability.CODE,
    )
    advertised = {
        cap
        for spec in selector._specs
        if spec.name == outcome.model
        for cap in spec.capabilities
    }
    assert Capability.CODE in advertised
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
@pytest.mark.asyncio
async def test_router_middleware():
    """A wrapped agent logs one routing decision and tags its reply."""
    from pyagent_patterns.base import Message

    cheap_llm = MockLLM(responses=["Cheap model response"])
    expensive_llm = MockLLM(responses=["Expensive model response"])
    registry = {"gpt-4o-mini": cheap_llm, "gpt-4o": expensive_llm}

    routed = RouterMiddleware(model_registry=registry).wrap(Agent("test_agent", cheap_llm))
    reply = await routed.run([Message.user("What is 1+1?")])

    assert reply.metadata.get("routed_model") is not None
    assert len(routed.routing_log) == 1
|