pyagent-router 0.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,18 @@
1
+ __pycache__/
2
+ *.py[cod]
3
+ *$py.class
4
+ *.egg-info/
5
+ dist/
6
+ build/
7
+ .eggs/
8
+ *.egg
9
+ .venv/
10
+ venv/
11
+ .mypy_cache/
12
+ .ruff_cache/
13
+ .pytest_cache/
14
+ .coverage
15
+ htmlcov/
16
+ site/
17
+ .env
18
+ *.log
@@ -0,0 +1,55 @@
1
+ Metadata-Version: 2.4
2
+ Name: pyagent-router
3
+ Version: 0.1.0
4
+ Summary: Difficulty-aware routing and model selection for multi-agent LLM workflows
5
+ License: MIT
6
+ Keywords: LLM,agents,cost-optimization,model-selection,routing
7
+ Classifier: Development Status :: 3 - Alpha
8
+ Classifier: Intended Audience :: Developers
9
+ Classifier: License :: OSI Approved :: MIT License
10
+ Classifier: Programming Language :: Python :: 3.11
11
+ Classifier: Programming Language :: Python :: 3.12
12
+ Classifier: Programming Language :: Python :: 3.13
13
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
14
+ Classifier: Typing :: Typed
15
+ Requires-Python: >=3.11
16
+ Requires-Dist: pyagent-patterns>=0.1.0
17
+ Provides-Extra: dev
18
+ Requires-Dist: mypy>=1.10; extra == 'dev'
19
+ Requires-Dist: pytest-asyncio>=0.23; extra == 'dev'
20
+ Requires-Dist: pytest>=8.0; extra == 'dev'
21
+ Requires-Dist: ruff>=0.5; extra == 'dev'
22
+ Provides-Extra: tiktoken
23
+ Requires-Dist: tiktoken>=0.7; extra == 'tiktoken'
24
+ Description-Content-Type: text/markdown
25
+
26
+ # pyagent-router
27
+
28
+ **Difficulty-aware routing and model selection** for multi-agent LLM workflows. Route easy tasks to cheap models, hard tasks to expensive ones.
29
+
30
+ ## Install
31
+
32
+ ```bash
33
+ pip install pyagent-router
34
+ ```
35
+
36
+ ## Components
37
+
38
+ - **DifficultyScorer** — Score task difficulty 1-10 based on heuristics
39
+ - **CostEstimator** — Estimate LLM call costs with built-in model pricing
40
+ - **ModelSelector** — Auto-select the cheapest viable model
41
+ - **RouterMiddleware** — Inject routing into agent calls
42
+
43
+ ## Quick Example
44
+
45
+ ```python
46
+ from pyagent_router import ModelSelector
47
+
48
+ result = ModelSelector().select("What is 2+2?")
49
+ print(f"{result.model}: ${result.cost_estimate.total_cost:.6f}")
50
+ # gpt-4.1-nano: $0.000001 (vs. about $0.00002 with gpt-4o)
51
+ ```
52
+
53
+ ## Typical Savings: 40-60%
54
+
55
+ For workloads where 70% of queries are easy, routing to cheap models saves 40-60% vs always using the most expensive model.
@@ -0,0 +1,30 @@
1
+ # pyagent-router
2
+
3
+ **Difficulty-aware routing and model selection** for multi-agent LLM workflows. Route easy tasks to cheap models, hard tasks to expensive ones.
4
+
5
+ ## Install
6
+
7
+ ```bash
8
+ pip install pyagent-router
9
+ ```
10
+
11
+ ## Components
12
+
13
+ - **DifficultyScorer** — Score task difficulty 1-10 based on heuristics
14
+ - **CostEstimator** — Estimate LLM call costs with built-in model pricing
15
+ - **ModelSelector** — Auto-select the cheapest viable model
16
+ - **RouterMiddleware** — Inject routing into agent calls
17
+
18
+ ## Quick Example
19
+
20
+ ```python
21
+ from pyagent_router import ModelSelector
22
+
23
+ result = ModelSelector().select("What is 2+2?")
24
+ print(f"{result.model}: ${result.cost_estimate.total_cost:.6f}")
25
+ # gpt-4.1-nano: $0.000001 (vs. about $0.00002 with gpt-4o)
26
+ ```
27
+
28
+ ## Typical Savings: 40-60%
29
+
30
+ For workloads where 70% of queries are easy, routing to cheap models saves 40-60% vs always using the most expensive model.
@@ -0,0 +1,30 @@
1
+ [build-system]
2
+ requires = ["hatchling"]
3
+ build-backend = "hatchling.build"
4
+
5
+ [project]
6
+ name = "pyagent-router"
7
+ version = "0.1.0"
8
+ description = "Difficulty-aware routing and model selection for multi-agent LLM workflows"
9
+ readme = "README.md"
10
+ requires-python = ">=3.11"
11
+ license = {text = "MIT"}
12
+ keywords = ["agents", "routing", "LLM", "cost-optimization", "model-selection"]
13
+ classifiers = [
14
+ "Development Status :: 3 - Alpha",
15
+ "Intended Audience :: Developers",
16
+ "License :: OSI Approved :: MIT License",
17
+ "Programming Language :: Python :: 3.11",
18
+ "Programming Language :: Python :: 3.12",
19
+ "Programming Language :: Python :: 3.13",
20
+ "Topic :: Scientific/Engineering :: Artificial Intelligence",
21
+ "Typing :: Typed",
22
+ ]
23
+ dependencies = ["pyagent-patterns>=0.1.0"]
24
+
25
+ [project.optional-dependencies]
26
+ tiktoken = ["tiktoken>=0.7"]
27
+ dev = ["pytest>=8.0", "pytest-asyncio>=0.23", "ruff>=0.5", "mypy>=1.10"]
28
+
29
+ [tool.hatch.build.targets.wheel]
30
+ packages = ["src/pyagent_router"]
@@ -0,0 +1,9 @@
1
+ """PyAgent Router — difficulty-aware routing and model selection for multi-agent LLM workflows."""
2
+
3
+ from pyagent_router.estimator import CostEstimator
4
+ from pyagent_router.middleware import RouterMiddleware
5
+ from pyagent_router.scorer import DifficultyScorer
6
+ from pyagent_router.selector import ModelSelector
7
+
8
# Names re-exported as the package's public API.
__all__ = [
    "DifficultyScorer",
    "CostEstimator",
    "ModelSelector",
    "RouterMiddleware",
]
__version__ = "0.1.0"
@@ -0,0 +1,134 @@
1
+ """CostEstimator: estimate LLM call cost based on token count and model pricing.
2
+
3
+ Maintains a registry of model pricing (per 1M input/output tokens)
4
+ and estimates cost for a given task and model.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ from dataclasses import dataclass
10
+
11
+
12
@dataclass(frozen=True)
class ModelPricing:
    """Pricing for a model per 1M tokens.

    Attributes:
        input_per_million: Price charged per one million input tokens.
        output_per_million: Price charged per one million output tokens.
    """

    input_per_million: float
    output_per_million: float


@dataclass(frozen=True)
class CostEstimate:
    """Estimated cost for a single LLM call.

    Attributes:
        model: Model name.
        input_tokens: Estimated input tokens.
        output_tokens: Estimated output tokens.
        input_cost: Cost for input tokens in USD.
        output_cost: Cost for output tokens in USD.
    """

    model: str
    input_tokens: int
    output_tokens: int
    input_cost: float
    output_cost: float

    @property
    def total_cost(self) -> float:
        """Total estimated cost in USD (input cost plus output cost)."""
        return self.input_cost + self.output_cost


# Pricing as of mid-2025 (per 1M tokens)
DEFAULT_PRICING: dict[str, ModelPricing] = {
    "gpt-4o": ModelPricing(2.50, 10.00),
    "gpt-4o-mini": ModelPricing(0.15, 0.60),
    "gpt-4.1": ModelPricing(2.00, 8.00),
    "gpt-4.1-mini": ModelPricing(0.40, 1.60),
    "gpt-4.1-nano": ModelPricing(0.10, 0.40),
    "o3": ModelPricing(10.00, 40.00),
    "o3-mini": ModelPricing(1.10, 4.40),
    "o4-mini": ModelPricing(1.10, 4.40),
    "claude-sonnet-4": ModelPricing(3.00, 15.00),
    "claude-haiku-3.5": ModelPricing(0.80, 4.00),
    "gemini-2.5-flash": ModelPricing(0.15, 0.60),
    "gemini-2.5-pro": ModelPricing(1.25, 10.00),
}

# Character-count heuristic used by estimate_from_text (no tokenizer involved).
_CHARS_PER_TOKEN = 4


class CostEstimator:
    """Estimate LLM call costs based on model pricing registry.

    Args:
        pricing: Optional custom pricing dict. When None (or empty) the
            built-in pricing table is used.
        default_output_ratio: Estimated output/input token ratio when output
            length unknown. Must not be negative.

    Raises:
        ValueError: If default_output_ratio is negative.
    """

    def __init__(
        self,
        pricing: dict[str, ModelPricing] | None = None,
        default_output_ratio: float = 0.5,
    ) -> None:
        if default_output_ratio < 0:
            raise ValueError(
                f"default_output_ratio must be >= 0, got {default_output_ratio}"
            )
        # Copy the defaults so per-instance changes never leak into the module table.
        self._pricing = pricing or dict(DEFAULT_PRICING)
        self._output_ratio = default_output_ratio

    def estimate(
        self,
        model: str,
        input_tokens: int,
        output_tokens: int | None = None,
    ) -> CostEstimate:
        """Estimate cost for a single LLM call.

        Args:
            model: Model name (must be in pricing registry).
            input_tokens: Number of input tokens. Must not be negative.
            output_tokens: Number of output tokens. If None, estimated as
                ``int(input_tokens * default_output_ratio)``.

        Returns:
            CostEstimate with breakdown.

        Raises:
            KeyError: If model not found in pricing registry.
            ValueError: If a token count is negative (it would otherwise
                produce a negative cost).
        """
        if model not in self._pricing:
            raise KeyError(
                f"Model '{model}' not in pricing registry. "
                f"Available: {', '.join(sorted(self._pricing.keys()))}"
            )
        if input_tokens < 0:
            raise ValueError(f"input_tokens must be >= 0, got {input_tokens}")
        if output_tokens is not None and output_tokens < 0:
            raise ValueError(f"output_tokens must be >= 0, got {output_tokens}")

        pricing = self._pricing[model]
        if output_tokens is None:
            est_output = int(input_tokens * self._output_ratio)
        else:
            est_output = output_tokens

        return CostEstimate(
            model=model,
            input_tokens=input_tokens,
            output_tokens=est_output,
            input_cost=input_tokens * pricing.input_per_million / 1_000_000,
            output_cost=est_output * pricing.output_per_million / 1_000_000,
        )

    def estimate_from_text(self, model: str, text: str) -> CostEstimate:
        """Estimate cost from raw text (approximates 4 chars per token).

        Non-empty text always counts as at least one input token. Plain floor
        division would give 0 tokens for text shorter than four characters,
        making every model cost exactly $0 and turning any cost comparison
        into a tie.

        Args:
            model: Model name (must be in pricing registry).
            text: The input text.

        Returns:
            CostEstimate whose output tokens are derived from the output ratio.

        Raises:
            KeyError: If model not found in pricing registry.
        """
        input_tokens = max(1, len(text) // _CHARS_PER_TOKEN) if text else 0
        return self.estimate(model, input_tokens)

    def compare(self, text: str, models: list[str] | None = None) -> list[CostEstimate]:
        """Compare costs across multiple models for the same input.

        Args:
            text: The input text.
            models: Models to compare. None or an empty list means all
                registered models. Names missing from the pricing registry
                are skipped rather than raising.

        Returns:
            List of CostEstimates sorted by total_cost ascending.
        """
        target_models = models or list(self._pricing.keys())
        estimates = [
            self.estimate_from_text(name, text)
            for name in target_models
            if name in self._pricing
        ]
        return sorted(estimates, key=lambda e: e.total_cost)

    @property
    def available_models(self) -> list[str]:
        """Registered model names in alphabetical order."""
        return sorted(self._pricing.keys())
@@ -0,0 +1,115 @@
1
+ """RouterMiddleware: inject routing into any pattern's agent calls.
2
+
3
+ Wraps an Agent so that every call is automatically routed to the
4
+ optimal model based on task difficulty and cost.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ from pyagent_patterns.base import Agent, LLMCallable, Message
10
+ from pyagent_router.selector import Capability, ModelSelector, SelectionResult
11
+
12
+
13
class RoutedAgent(Agent):
    """An agent wrapper that routes each call through ModelSelector.

    The routing decision is recorded in metadata for tracing.

    Args:
        agent: The original agent.
        selector: ModelSelector to use for routing decisions.
        model_registry: Mapping of model names to LLM callables.
        required_capability: Optional capability filter for model selection.

    Attributes:
        routing_log: Every SelectionResult produced by ``run``, in call order.
    """

    def __init__(
        self,
        agent: Agent,
        selector: ModelSelector,
        model_registry: dict[str, LLMCallable],
        required_capability: Capability | None = None,
    ) -> None:
        # Mirror the wrapped agent's identity so the wrapper can stand in for it.
        super().__init__(
            name=agent.name,
            llm=agent.llm,
            system_prompt=agent.system_prompt,
            description=agent.description,
        )
        self._original_agent = agent
        self._selector = selector
        self._model_registry = model_registry
        self._required_capability = required_capability
        self.routing_log: list[SelectionResult] = []

    async def run(self, messages: list[Message]) -> Message:
        """Route the call to the selected model, then execute.

        Args:
            messages: Conversation to send; the non-empty contents are joined
                with spaces to form the text that is scored for difficulty.

        Returns:
            A new Message carrying the result's role, content and name, plus
            routing metadata: ``routed_model``, ``difficulty``,
            ``estimated_cost``, ``reason`` and ``routed_fallback``.
        """
        # Extract task text from messages for difficulty scoring
        task_text = " ".join(m.content for m in messages if m.content)

        selection = self._selector.select(task_text, self._required_capability)
        self.routing_log.append(selection)

        # The selector knows nothing about the registry, so it may pick a model
        # with no registered callable. In that case the wrapped agent's own LLM
        # runs; flag it so traces do not claim a routing that never happened.
        used_fallback = selection.model not in self._model_registry
        if used_fallback:
            llm = self._original_agent.llm
        else:
            llm = self._model_registry[selection.model]

        # Create a temporary agent with the routed LLM
        routed = Agent(
            name=self._original_agent.name,
            llm=llm,
            system_prompt=self._original_agent.system_prompt,
            description=self._original_agent.description,
        )
        result = await routed.run(messages)

        # Attach routing metadata to the message
        return Message(
            role=result.role,
            content=result.content,
            name=result.name,
            metadata={
                **result.metadata,
                "routed_model": selection.model,
                "difficulty": selection.difficulty.score,
                "estimated_cost": selection.cost_estimate.total_cost,
                "reason": selection.reason,
                "routed_fallback": used_fallback,
            },
        )
78
+
79
+
80
class RouterMiddleware:
    """Factory that turns plain agents into automatically routed ones.

    Example:
        middleware = RouterMiddleware(model_registry={"gpt-4o": my_gpt4o, "gpt-4o-mini": my_mini})
        routed_agent = middleware.wrap(my_agent)
        # routed_agent now picks a model on every call

    Args:
        model_registry: Mapping of model names to LLM callables.
        selector: ModelSelector shared by all wrapped agents. A default
            ModelSelector is created when omitted.
        required_capability: Capability filter handed to every wrapped agent.
    """

    def __init__(
        self,
        model_registry: dict[str, LLMCallable],
        selector: ModelSelector | None = None,
        required_capability: Capability | None = None,
    ) -> None:
        self._registry = model_registry
        self._capability = required_capability
        self._selector = selector or ModelSelector()

    def wrap(self, agent: Agent) -> RoutedAgent:
        """Return a RoutedAgent built around ``agent`` with this middleware's settings."""
        routed = RoutedAgent(
            agent=agent,
            selector=self._selector,
            model_registry=self._registry,
            required_capability=self._capability,
        )
        return routed

    def wrap_all(self, agents: list[Agent]) -> list[RoutedAgent]:
        """Wrap each agent in ``agents``, preserving order."""
        return list(map(self.wrap, agents))
File without changes
@@ -0,0 +1,121 @@
1
+ """DifficultyScorer: estimate task difficulty to inform model routing.
2
+
3
+ Uses heuristics (word count, keyword complexity, question structure) and
4
+ optionally caller-supplied custom signals (e.g. an LLM classifier) to score difficulty 1-10.
5
+
6
+ Based on: arxiv:2509.11079 "Difficulty-Aware Agent Orchestration"
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ import re
12
+ from dataclasses import dataclass, field
13
+ from typing import Any
14
+
15
+
16
@dataclass(frozen=True)
class DifficultyScore:
    """Result of difficulty scoring.

    Attributes:
        score: Difficulty score from 1 (trivial) to 10 (extremely hard).
        signals: Dictionary of individual signal scores that contributed.
        category: Human-readable difficulty category.
    """

    score: int
    signals: dict[str, float] = field(default_factory=dict)
    category: str = ""

    @property
    def is_easy(self) -> bool:
        """True for scores 1-3."""
        return self.score <= 3

    @property
    def is_medium(self) -> bool:
        """True for scores 4-6."""
        return 4 <= self.score <= 6

    @property
    def is_hard(self) -> bool:
        """True for scores 7-10."""
        return self.score >= 7


# Keywords indicating higher complexity
_COMPLEX_KEYWORDS = {
    "analyze", "compare", "contrast", "evaluate", "synthesize",
    "design", "architect", "optimize", "debug", "prove",
    "multi-step", "trade-off", "implications", "comprehensive",
    "algorithm", "mathematical", "formal", "derivation",
}

_SIMPLE_KEYWORDS = {
    "what is", "define", "list", "name", "when was",
    "who is", "how many", "yes or no", "true or false",
    "translate", "convert", "summarize",
}


def _word_start_patterns(keywords: set[str]) -> tuple[Any, ...]:
    """Compile one pattern per keyword, anchored at the start of a word.

    Anchoring keeps inflected forms ("architectures", "trade-offs") while
    rejecting keywords buried inside other words ("improve" is not "prove").
    """
    return tuple(re.compile(r"\b" + re.escape(kw)) for kw in sorted(keywords))


_COMPLEX_PATTERNS = _word_start_patterns(_COMPLEX_KEYWORDS)
_SIMPLE_PATTERNS = _word_start_patterns(_SIMPLE_KEYWORDS)

# Code markers are matched case-sensitively against the raw task text.
_CODE_PATTERNS = tuple(
    re.compile(p)
    for p in (r"```", r"\bdef ", r"\bclass ", r"\bfunction", r"\bimport\b")
)
# Math markers are matched against the lower-cased task text.
_MATH_PATTERNS = tuple(
    re.compile(p)
    for p in (r"∑", r"∫", r"\bequation", r"\bformula", r"\bproof")
)
# A list marker such as "1." or "12.": digits not glued to a preceding word or
# dot, followed by whitespace or end of text. Decimals like "3.14" don't match.
_LIST_MARKER = re.compile(r"(?<![\w.])\d+\.(?=\s|$)")

# Relative weight of each built-in signal; the weights sum to 1.0.
_BUILTIN_WEIGHTS = {"length": 0.15, "keywords": 0.35, "multi_part": 0.25, "technical": 0.25}


def _count_hits(patterns: tuple[Any, ...], text: str) -> int:
    """Return how many of ``patterns`` occur at least once in ``text``."""
    return sum(1 for pattern in patterns if pattern.search(text))


class DifficultyScorer:
    """Heuristic-based task difficulty scorer.

    Scores tasks on a 1-10 scale using multiple signals:
    - Length (word count)
    - Keyword complexity
    - Multi-part structure (numbered steps, several questions)
    - Code/math indicators

    Args:
        custom_signals: Optional dict of custom signal functions.
            Each function takes a task string and returns a float 0-1;
            values outside that range are clamped.
    """

    def __init__(
        self,
        custom_signals: dict[str, Any] | None = None,
    ) -> None:
        self._custom_signals = custom_signals or {}

    def score(self, task: str) -> DifficultyScore:
        """Score the difficulty of a task string.

        Args:
            task: The task text to analyze.

        Returns:
            DifficultyScore with the 1-10 score, the individual signal values
            and the category ("easy" for 1-3, "medium" for 4-6, else "hard").
        """
        signals: dict[str, float] = {}

        # Signal 1: Length complexity (longer = harder, saturating at 200 words)
        word_count = len(task.split())
        signals["length"] = min(word_count / 200, 1.0)

        # Signal 2: Keyword complexity. Simple keywords pull the signal down
        # by at most 0.5, and the result never goes below zero.
        task_lower = task.lower()
        complex_hits = _count_hits(_COMPLEX_PATTERNS, task_lower)
        simple_hits = _count_hits(_SIMPLE_PATTERNS, task_lower)
        signals["keywords"] = max(
            0.0, min(complex_hits / 3, 1.0) - min(simple_hits / 3, 0.5)
        )

        # Signal 3: Multi-part questions (numbered steps, multiple questions).
        # The first part is free, hence the "- 1".
        multi_part = len(_LIST_MARKER.findall(task)) + task.count("?") - 1
        signals["multi_part"] = min(max(multi_part, 0) / 5, 1.0)

        # Signal 4: Code/math indicators
        code_indicators = _count_hits(_CODE_PATTERNS, task)
        math_indicators = _count_hits(_MATH_PATTERNS, task_lower)
        signals["technical"] = min((code_indicators + math_indicators) / 3, 1.0)

        # Custom signals, clamped to the documented 0-1 range so one
        # misbehaving function cannot dominate the final score.
        for name, fn in self._custom_signals.items():
            signals[name] = min(max(float(fn(task)), 0.0), 1.0)

        # Weighted average of the built-in signals (0-1)
        weighted_sum = sum(signals.get(k, 0) * w for k, w in _BUILTIN_WEIGHTS.items())

        # Custom signals share 30% of the total, averaged with equal weight
        if self._custom_signals:
            custom_avg = sum(signals.get(k, 0) for k in self._custom_signals) / len(self._custom_signals)
            weighted_sum = weighted_sum * 0.7 + custom_avg * 0.3

        # Map 0-1 onto the 1-10 scale.
        raw_score = max(1, min(10, int(weighted_sum * 10) + 1))

        if raw_score <= 3:
            category = "easy"
        elif raw_score <= 6:
            category = "medium"
        else:
            category = "hard"

        return DifficultyScore(score=raw_score, signals=signals, category=category)
@@ -0,0 +1,156 @@
1
+ """ModelSelector: pick the best model based on difficulty, cost, and capability constraints.
2
+
3
+ Combines DifficultyScorer output with CostEstimator to select the
4
+ cheapest model that meets quality requirements.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ from dataclasses import dataclass, field
10
+ from enum import Enum
11
+
12
+ from pyagent_router.estimator import CostEstimate, CostEstimator
13
+ from pyagent_router.scorer import DifficultyScore, DifficultyScorer
14
+
15
+
16
class Capability(str, Enum):
    """Model capabilities for filtering."""

    CODE = "code"
    MATH = "math"
    REASONING = "reasoning"
    CREATIVE = "creative"
    GENERAL = "general"
    VISION = "vision"


@dataclass
class ModelSpec:
    """Specification of a model's capabilities and constraints.

    Attributes:
        name: Model identifier (must match pricing registry).
        min_difficulty: Minimum difficulty score this model should handle.
        max_difficulty: Maximum difficulty score this model should handle.
        capabilities: Set of capabilities this model excels at.
        max_context: Maximum context window in tokens.
    """

    name: str
    min_difficulty: int = 1
    max_difficulty: int = 10
    capabilities: set[Capability] = field(default_factory=lambda: {Capability.GENERAL})
    max_context: int = 128_000


# Ordered from least to most capable; the fallback in ModelSelector.select
# relies on this ordering.
DEFAULT_MODEL_SPECS: list[ModelSpec] = [
    ModelSpec("gpt-4.1-nano", 1, 3, {Capability.GENERAL}, 1_000_000),
    ModelSpec("gpt-4o-mini", 1, 5, {Capability.GENERAL, Capability.CODE}, 128_000),
    ModelSpec("gpt-4.1-mini", 1, 6, {Capability.GENERAL, Capability.CODE, Capability.REASONING}, 1_000_000),
    ModelSpec("gpt-4o", 3, 8, {Capability.GENERAL, Capability.CODE, Capability.CREATIVE, Capability.VISION}, 128_000),
    ModelSpec("gpt-4.1", 4, 9, {Capability.GENERAL, Capability.CODE, Capability.REASONING}, 1_000_000),
    ModelSpec("claude-sonnet-4", 5, 10, {Capability.GENERAL, Capability.CODE, Capability.REASONING, Capability.CREATIVE}, 200_000),
    ModelSpec("o3-mini", 6, 10, {Capability.REASONING, Capability.MATH, Capability.CODE}, 200_000),
    ModelSpec("o3", 8, 10, {Capability.REASONING, Capability.MATH, Capability.CODE}, 200_000),
]


@dataclass(frozen=True)
class SelectionResult:
    """Result of model selection.

    Attributes:
        model: Selected model name.
        difficulty: The difficulty assessment.
        cost_estimate: Estimated cost for this model.
        reason: Human-readable explanation of why this model was chosen.
        alternatives: Other candidate model names, cheapest first.
    """

    model: str
    difficulty: DifficultyScore
    cost_estimate: CostEstimate
    reason: str
    alternatives: list[str] = field(default_factory=list)


class ModelSelector:
    """Select the optimal model based on task difficulty and cost.

    Strategy: find the cheapest model whose difficulty range covers the task
    and whose capabilities match the required capability (if specified).

    Args:
        specs: List of ModelSpec definitions, ordered from least to most
            capable. Defaults to built-in specs.
        cost_estimator: CostEstimator instance. Created automatically if None.
        scorer: DifficultyScorer instance. Created automatically if None.
    """

    def __init__(
        self,
        specs: list[ModelSpec] | None = None,
        cost_estimator: CostEstimator | None = None,
        scorer: DifficultyScorer | None = None,
    ) -> None:
        self._specs = specs or list(DEFAULT_MODEL_SPECS)
        self._estimator = cost_estimator or CostEstimator()
        self._scorer = scorer or DifficultyScorer()

    def select(
        self,
        task: str,
        required_capability: Capability | None = None,
    ) -> SelectionResult:
        """Select the best model for a given task.

        Args:
            task: The task text to analyze.
            required_capability: Optional capability filter.

        Returns:
            SelectionResult with chosen model and reasoning.
        """
        difficulty = self._scorer.score(task)

        # Specs that offer the required capability, in registry order.
        capable = [
            spec
            for spec in self._specs
            if required_capability is None or required_capability in spec.capabilities
        ]
        # Of those, keep the ones whose difficulty range covers the task.
        candidates = [
            spec
            for spec in capable
            if spec.min_difficulty <= difficulty.score <= spec.max_difficulty
        ]

        if not candidates:
            # No spec covers this difficulty. Fall back to the most capable
            # model that still has the required capability, so the filter is
            # not silently dropped. Only when no spec has the capability at
            # all does the overall most capable model take over.
            candidates = [capable[-1]] if capable else [self._specs[-1]]

        # Estimate cost for each candidate; models without pricing are skipped.
        scored: list[tuple[ModelSpec, CostEstimate]] = []
        for spec in candidates:
            try:
                estimate = self._estimator.estimate_from_text(spec.name, task)
            except KeyError:
                continue
            scored.append((spec, estimate))

        if not scored:
            # Last resort fallback: first spec with a zero-cost placeholder.
            spec = self._specs[0]
            estimate = CostEstimate(spec.name, len(task) // 4, len(task) // 8, 0.0, 0.0)
            scored = [(spec, estimate)]

        # Stable sort: equally priced candidates keep their registry order.
        scored.sort(key=lambda pair: pair[1].total_cost)
        chosen_spec, chosen_cost = scored[0]
        alternatives = [other.name for other, _ in scored[1:]]

        reason = (
            f"Difficulty {difficulty.score}/10 ({difficulty.category}) → "
            f"{chosen_spec.name} (cheapest candidate at ${chosen_cost.total_cost:.6f})"
        )

        return SelectionResult(
            model=chosen_spec.name,
            difficulty=difficulty,
            cost_estimate=chosen_cost,
            reason=reason,
            alternatives=alternatives,
        )
File without changes
@@ -0,0 +1,81 @@
1
+ """Tests for pyagent-router."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import pytest
6
+
7
+ from pyagent_patterns.base import Agent, MockLLM
8
+ from pyagent_router.estimator import CostEstimator
9
+ from pyagent_router.middleware import RouterMiddleware
10
+ from pyagent_router.scorer import DifficultyScorer
11
+ from pyagent_router.selector import Capability, ModelSelector
12
+
13
+
14
def test_difficulty_scorer_easy():
    """A trivial arithmetic question lands in the easy bucket."""
    verdict = DifficultyScorer().score("What is 2 + 2?")
    assert verdict.is_easy
    assert verdict.category == "easy"
19
+
20
+
21
def test_difficulty_scorer_hard():
    """A multi-step architecture prompt scores at least medium."""
    prompt = (
        "Analyze and compare the trade-offs between microservice and monolithic "
        "architectures. Design an optimal system that synthesizes the benefits of both. "
        "1. Evaluate scalability. 2. Evaluate maintainability. 3. Evaluate cost. "
        "4. Propose a hybrid architecture. 5. Prove its optimality."
    )
    verdict = DifficultyScorer().score(prompt)
    assert verdict.score >= 4  # Should be medium or hard
30
+
31
+
32
def test_cost_estimator_basic():
    """Explicit token counts are priced with the built-in gpt-4o-mini rates."""
    quote = CostEstimator().estimate("gpt-4o-mini", input_tokens=1000, output_tokens=500)
    assert quote.model == "gpt-4o-mini"
    assert quote.input_cost == 1000 * 0.15 / 1_000_000
    assert quote.output_cost == 500 * 0.60 / 1_000_000
    assert quote.total_cost == quote.input_cost + quote.output_cost
39
+
40
+
41
def test_cost_estimator_compare():
    """compare() returns one estimate per requested model, cheapest first."""
    ranked = CostEstimator().compare("Short text", models=["gpt-4o", "gpt-4o-mini"])
    assert len(ranked) == 2
    cheapest, priciest = ranked[0], ranked[1]
    assert cheapest.total_cost <= priciest.total_cost  # sorted ascending
46
+
47
+
48
def test_model_selector_easy_task():
    """A simple factual question is routed to one of the cheap models."""
    cheap_models = ["gpt-4.1-nano", "gpt-4o-mini", "gpt-4.1-mini"]
    choice = ModelSelector().select("What is the capital of France?")
    # Easy task should select a cheap model
    assert choice.difficulty.is_easy or choice.difficulty.is_medium
    assert choice.model in cheap_models
54
+
55
+
56
def test_model_selector_with_capability():
    """The chosen model's spec lists the capability that was required."""
    selector = ModelSelector()
    choice = selector.select(
        "Write a complex algorithm for graph traversal",
        required_capability=Capability.CODE,
    )
    # Gather the capabilities of every spec matching the chosen model name.
    chosen_capabilities = []
    for spec in selector._specs:
        if spec.name == choice.model:
            chosen_capabilities.extend(spec.capabilities)
    assert Capability.CODE in chosen_capabilities
65
+
66
+
67
@pytest.mark.asyncio
async def test_router_middleware():
    """A wrapped agent records routing metadata and one log entry per call."""
    from pyagent_patterns.base import Message

    cheap_llm = MockLLM(responses=["Cheap model response"])
    expensive_llm = MockLLM(responses=["Expensive model response"])
    registry = {"gpt-4o-mini": cheap_llm, "gpt-4o": expensive_llm}

    routed = RouterMiddleware(model_registry=registry).wrap(Agent("test_agent", cheap_llm))

    reply = await routed.run([Message.user("What is 1+1?")])
    assert reply.metadata.get("routed_model") is not None
    assert len(routed.routing_log) == 1