sentimentizer 0.99.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,60 @@
1
+ import logging
2
+ import sys
3
+ import time
4
+ from collections.abc import Callable
5
+ from functools import wraps
6
+ from pathlib import Path
7
+ from typing import Any, TextIO
8
+
9
+ import psutil
10
+ import structlog
11
+
12
# Path of this module's source file.
file_path = Path(__file__)
# Grandparent directory of this file, made absolute. Presumably the directory
# that contains the package (e.g. ``src/`` or site-packages) — TODO confirm the
# intended use; it is not referenced in the visible code.
root = file_path.parent.parent.absolute()
14
+
15
+
16
def new_logger(level: int = 20, output: TextIO = sys.stderr) -> Any:
    """Configure structlog process-wide and return a JSON-emitting logger.

    Each call re-runs ``structlog.configure``, so the latest call's ``level``
    and ``output`` become the global structlog settings. With
    ``cache_logger_on_first_use=True`` a logger keeps the configuration that
    was active when it was first used.

    Events pass through: context-var merging, log-level tagging, exception
    formatting, an ISO-8601 UTC timestamp, and finally JSON rendering.

    Args:
        level: Minimum numeric level to emit (20 is ``logging.INFO``).
        output: Text stream the JSON lines are printed to.

    Returns:
        A structlog bound logger. Typed as ``Any`` because it accepts
        arbitrary keyword arguments as event key/value pairs, which static
        type checkers cannot express.
    """
    pipeline = [
        structlog.contextvars.merge_contextvars,
        structlog.processors.add_log_level,
        structlog.processors.format_exc_info,
        structlog.processors.TimeStamper(fmt="iso", utc=True),
        structlog.processors.JSONRenderer(),
    ]
    level_filtered_wrapper = structlog.make_filtering_bound_logger(level)

    structlog.configure(
        processors=pipeline,
        wrapper_class=level_filtered_wrapper,
        logger_factory=structlog.PrintLoggerFactory(file=output),
        cache_logger_on_first_use=True,
    )
    # get_logger is the snake_case spelling of structlog.getLogger.
    return structlog.get_logger(__name__)
36
+
37
+
38
# Module-level logger used by ``time_decorator``. Creating it at import time
# runs new_logger, which (re)configures structlog globally at INFO level.
logger: Any = new_logger(logging.INFO)
39
+
40
+
41
def time_decorator(func: Callable[..., Any]) -> Callable[..., Any]:
    """Log run time and system memory stats after ``func`` returns.

    The wrapped function is called with the same arguments and its result is
    returned unchanged. After a normal return, one
    ``"function completed successfully"`` event is logged. It carries the
    function name, the elapsed ``time.perf_counter`` seconds, and the
    available/free/used system memory in GiB. Nothing is logged if ``func``
    raises.
    """

    @wraps(func)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        ts = time.perf_counter()
        result = func(*args, **kwargs)
        te = time.perf_counter()
        # One snapshot keeps the three memory figures mutually consistent
        # (previously psutil was queried three separate times).
        memory = psutil.virtual_memory()
        gib = 1024**3
        logger.info(
            "function completed successfully",
            function=func.__name__,
            run_time=f"{te - ts: 2.4f} seconds",
            available_memory=f"{memory.available / gib: .2f} GBs",
            free_memory=f"{memory.free / gib: .2f} GBs",
            used_memory=f"{memory.used / gib: .2f} GBs",
        )
        return result

    return wrapper
@@ -0,0 +1,11 @@
1
+ """Sentimentizer hyperparameter tuning agent.
2
+
3
+ Uses Pydantic AI Slim (GLM 5.1 via Ollama) for LLM reasoning,
4
+ LangGraph for workflow orchestration, and Ray Tune + Optuna for
5
+ hyperparameter search.
6
+ """
7
+
8
+ from sentimentizer.agent.graph import run_agent_tuning
9
+ from sentimentizer.agent.loader import AgentConfig, TunerConfig, load_agent_config
10
+
11
# Public API of the agent subpackage: the names imported above.
__all__ = ["AgentConfig", "TunerConfig", "load_agent_config", "run_agent_tuning"]
@@ -0,0 +1,143 @@
1
+ """Pydantic AI agent definitions for the tuning agent.
2
+
3
+ Two agents work together:
4
+ 1. AnalysisAgent — examines training metrics and diagnoses issues
5
+ 2. StrategyAgent — decides the next tuning strategy and search space
6
+
7
+ Both use GLM 5.1 via Ollama's OpenAI-compatible API.
8
+ Pydantic AI validates the LLM's structured output, rejecting
9
+ hallucinated or invalid configurations.
10
+ """
11
+
12
+ from __future__ import annotations
13
+
14
+ from dataclasses import dataclass
15
+ from typing import Any
16
+
17
+ from pydantic_ai import Agent, RunContext
18
+ from pydantic_ai.models.openai import OpenAIModel
19
+ from pydantic_ai.providers.openai import OpenAIProvider
20
+
21
+ from sentimentizer import new_logger
22
+ from sentimentizer.agent.loader import AgentConfig
23
+ from sentimentizer.agent.models import AnalysisResult, TuningDecision
24
+ from sentimentizer.agent.prompts import ANALYSIS_SYSTEM_PROMPT, STRATEGY_SYSTEM_PROMPT
25
+ from sentimentizer.config import DEFAULT_LOG_LEVEL
26
+
27
# Module logger. Calling new_logger also re-runs structlog's global
# configuration with DEFAULT_LOG_LEVEL. Not referenced in the visible code.
logger = new_logger(DEFAULT_LOG_LEVEL)
28
+
29
+
30
@dataclass
class TuningDeps:
    """Dependencies injected into both agents via ``RunContext.deps``.

    The agents' tools read these fields to give the LLM access to tuning
    history, current metrics, default search spaces and the model type being
    tuned. Fields left as ``None`` are returned by the tools as empty
    containers.
    """

    # Model family being tuned; the tools describe it as rnn, encoder or decoder.
    model_type: str = "rnn"
    # Past tuning iterations, one dict per iteration (schema defined by the caller).
    history: list[dict[str, Any]] | None = None
    # Current iteration's metrics. NOTE(review): the strategy agent's
    # ``get_analysis`` tool also returns this field — confirm the caller stores
    # the analysis output here when running that agent.
    current_metrics: dict[str, Any] | None = None
    # Default search space keyed by parameter name; presumably taken from the
    # YAML config's ``search_spaces`` section — confirm against the caller.
    search_space_defaults: dict[str, dict[str, Any]] | None = None
42
+
43
+
44
def _create_model(config: AgentConfig) -> OpenAIModel:
    """Build an OpenAI-compatible chat model that talks to Ollama.

    Ollama serves an OpenAI-compatible API under ``/v1``. This reuses
    pydantic-ai-slim's ``OpenAIModel``/``OpenAIProvider`` pair with the
    provider's ``base_url`` pointed at ``config.ollama_base_url``.

    Args:
        config: Agent settings supplying ``model_name`` and ``ollama_base_url``.

    Returns:
        An ``OpenAIModel`` ready to be handed to ``pydantic_ai.Agent``.
    """
    # Ollama does not check the key, so a placeholder value is supplied.
    return OpenAIModel(
        model_name=config.model_name,
        provider=OpenAIProvider(api_key="ollama", base_url=config.ollama_base_url),
    )
59
+
60
+
61
def create_analysis_agent(config: AgentConfig) -> Agent[TuningDeps, AnalysisResult]:
    """Create the analysis agent that examines training metrics.

    Intended behaviour, driven by ``ANALYSIS_SYSTEM_PROMPT`` from
    ``sentimentizer.agent.prompts``:
    - Read validation loss, accuracy, and training loss curves
    - Detect overfitting (val_loss rising while train_loss falls)
    - Detect underfitting (both losses remain high)
    - Assess whether the learning rate is appropriate
    - Suggest which parameters to focus on next

    Args:
        config: Agent settings; supplies the Ollama-backed model plus the
            ``temperature`` and ``max_tokens`` model settings.

    Returns:
        An ``Agent`` whose output is validated as an ``AnalysisResult``. It
        exposes three read-only tools over ``TuningDeps``.
    """
    model = _create_model(config)

    agent = Agent(
        model=model,
        output_type=AnalysisResult,
        deps_type=TuningDeps,
        system_prompt=ANALYSIS_SYSTEM_PROMPT,
        model_settings={"temperature": config.temperature, "max_tokens": config.max_tokens},
    )

    # pydantic-ai derives each tool's name, argument schema and description
    # from the function name, signature and docstring. Those are therefore
    # part of what the LLM sees; change them deliberately.

    # Past iterations, or [] when no history was injected.
    @agent.tool
    def get_previous_results(ctx: RunContext[TuningDeps]) -> list[dict[str, Any]]:
        """Get the history of past tuning iterations."""
        return ctx.deps.history or []

    # Current metrics, or {} when none were injected.
    @agent.tool
    def get_current_metrics(ctx: RunContext[TuningDeps]) -> dict[str, Any]:
        """Get the current iteration's training metrics."""
        return ctx.deps.current_metrics or {}

    @agent.tool
    def get_model_type(ctx: RunContext[TuningDeps]) -> str:
        """Get the model type being tuned (rnn, encoder, decoder)."""
        return ctx.deps.model_type

    return agent
99
+
100
+
101
def create_strategy_agent(config: AgentConfig) -> Agent[TuningDeps, TuningDecision]:
    """Create the strategy agent that decides the next tuning action.

    Intended behaviour, driven by ``STRATEGY_SYSTEM_PROMPT`` from
    ``sentimentizer.agent.prompts``:
    - Receive the analysis of training metrics
    - Decide whether to widen, narrow, change focus, or stop (the output
      schema also allows ``increase_epochs``)
    - Produce a new search space configuration
    - Set the number of trials for the next Ray Tune run

    Args:
        config: Agent settings; supplies the Ollama-backed model plus the
            ``temperature`` and ``max_tokens`` model settings.

    Returns:
        An ``Agent`` whose output is validated as a ``TuningDecision``, with
        ``num_samples`` constrained to 5-100. It exposes four read-only tools
        over ``TuningDeps``.
    """
    model = _create_model(config)

    agent = Agent(
        model=model,
        output_type=TuningDecision,
        deps_type=TuningDeps,
        system_prompt=STRATEGY_SYSTEM_PROMPT,
        model_settings={"temperature": config.temperature, "max_tokens": config.max_tokens},
    )

    # pydantic-ai derives each tool's name, argument schema and description
    # from the function name, signature and docstring. Those are therefore
    # part of what the LLM sees; change them deliberately.

    # NOTE(review): despite its name and docstring, this returns
    # ``deps.current_metrics``; TuningDeps has no separate analysis field.
    # Confirm the caller stores the analysis result in current_metrics.
    @agent.tool
    def get_analysis(ctx: RunContext[TuningDeps]) -> dict[str, Any]:
        """Get the current analysis results from the analysis agent."""
        return ctx.deps.current_metrics or {}

    @agent.tool
    def get_previous_results(ctx: RunContext[TuningDeps]) -> list[dict[str, Any]]:
        """Get the history of past tuning iterations."""
        return ctx.deps.history or []

    # Default search space, or {} when none was injected.
    @agent.tool
    def get_default_search_space(ctx: RunContext[TuningDeps]) -> dict[str, dict[str, Any]]:
        """Get the default search space parameters from the YAML config."""
        return ctx.deps.search_space_defaults or {}

    @agent.tool
    def get_model_type(ctx: RunContext[TuningDeps]) -> str:
        """Get the model type being tuned (rnn, encoder, decoder)."""
        return ctx.deps.model_type

    return agent
@@ -0,0 +1,116 @@
1
+ # Sentimentizer Agent Configuration
2
+ # Controls the hyperparameter tuning agent powered by GLM 5.1 via Ollama
3
+
4
+ agent:
5
+ # LLM configuration (Ollama OpenAI-compatible endpoint)
6
+ model_name: "glm-5.1:cloud"
7
+ ollama_base_url: "http://localhost:11434/v1"
8
+ temperature: 0.3
9
+ max_tokens: 2048
10
+
11
+ # Agent loop configuration
12
+ max_iterations: 5
13
+ convergence_threshold: 0.005 # Stop if improvement < this over last 3 iterations
14
+ initial_search_strategy: "bayesian" # bayesian | grid | random
15
+
16
+ # LangGraph checkpointing (resume interrupted tuning sessions)
17
+ checkpointing:
18
+ enabled: true
19
+ db_path: "agent_checkpoints.db"
20
+
21
+ # Human-in-the-loop (pause before expensive training runs)
22
+ human_in_the_loop: false
23
+
24
+ tuner:
25
+ # Ray Tune + Optuna scheduler configuration
26
+ scheduler: "asha" # asha | hyperband | median
27
+ metric: "val_accuracy"
28
+ mode: "max"
29
+ num_samples: 20
30
+ grace_period: 2
31
+ reduction_factor: 3
32
+
33
+ # Search spaces per model type
34
+ # Each parameter supports: loguniform, uniform, choice, randint
35
+ search_spaces:
36
+ rnn:
37
+ lr:
38
+ type: "loguniform"
39
+ low: 1.0e-5
40
+ high: 1.0e-2
41
+ hidden_size:
42
+ type: "choice"
43
+ values: [128, 256, 512]
44
+ num_layers:
45
+ type: "randint"
46
+ low: 1
47
+ high: 4
48
+ dropout:
49
+ type: "uniform"
50
+ low: 0.1
51
+ high: 0.5
52
+ weight_decay:
53
+ type: "loguniform"
54
+ low: 1.0e-6
55
+ high: 1.0e-2
56
+ batch_size:
57
+ type: "choice"
58
+ values: [32, 64, 128]
59
+
60
+ encoder:
61
+ lr:
62
+ type: "loguniform"
63
+ low: 1.0e-5
64
+ high: 5.0e-3
65
+ d_model:
66
+ type: "choice"
67
+ values: [128, 256, 512]
68
+ n_heads:
69
+ type: "choice"
70
+ values: [2, 4, 8]
71
+ n_layers:
72
+ type: "randint"
73
+ low: 1
74
+ high: 6
75
+ dropout:
76
+ type: "uniform"
77
+ low: 0.1
78
+ high: 0.5
79
+ weight_decay:
80
+ type: "loguniform"
81
+ low: 1.0e-6
82
+ high: 1.0e-2
83
+ batch_size:
84
+ type: "choice"
85
+ values: [32, 64, 128]
86
+
87
+ decoder:
88
+ lr:
89
+ type: "loguniform"
90
+ low: 1.0e-5
91
+ high: 1.0e-2
92
+ d_model:
93
+ type: "choice"
94
+ values: [128, 256, 512]
95
+ n_heads:
96
+ type: "choice"
97
+ values: [2, 4, 8]
98
+ n_encoder_layers:
99
+ type: "randint"
100
+ low: 1
101
+ high: 4
102
+ n_decoder_layers:
103
+ type: "randint"
104
+ low: 2
105
+ high: 6
106
+ dropout:
107
+ type: "uniform"
108
+ low: 0.1
109
+ high: 0.5
110
+ weight_decay:
111
+ type: "loguniform"
112
+ low: 1.0e-6
113
+ high: 1.0e-2
114
+ batch_size:
115
+ type: "choice"
116
+ values: [32, 64, 128]
@@ -0,0 +1,173 @@
1
+ """LangGraph state graph for the tuning agent.
2
+
3
+ Defines the workflow: analyze → decide → tune → evaluate → (loop or end)
4
+
5
+ Uses LangGraph for orchestration and checkpointing, with Pydantic AI
6
+ agents (GLM 5.1 via Ollama) as the LLM reasoning layer inside nodes.
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ from pathlib import Path
12
+ from typing import Any
13
+
14
+ from langgraph.graph import END, StateGraph
15
+
16
+ from sentimentizer import new_logger
17
+ from sentimentizer.agent.loader import load_agent_config
18
+ from sentimentizer.agent.models import AgentRunResult
19
+ from sentimentizer.agent.nodes import analyze, decide, evaluate, tune
20
+ from sentimentizer.agent.state import AgentState
21
+ from sentimentizer.config import DEFAULT_LOG_LEVEL
22
+
23
# Module logger. Calling new_logger also re-runs structlog's global
# configuration with DEFAULT_LOG_LEVEL.
logger = new_logger(DEFAULT_LOG_LEVEL)
24
+
25
+
26
def should_continue(state: AgentState) -> str:
    """Conditional edge: decide whether to continue or end the loop.

    Stops when the state is flagged ``converged``, or when ``iteration``
    reaches a hard safety limit. The limit is twice the configured
    ``max_iterations``, or 20 if the agent config cannot be loaded.

    Args:
        state: Current graph state; reads ``converged``, ``iteration`` and
            ``agent_config_path``.

    Returns:
        ``'analyze'`` to continue iterating, or ``'end'`` to stop.
    """
    if state.get("converged", False):
        return "end"

    # Safety net independent of the convergence logic: never loop forever.
    iteration = state.get("iteration", 0)
    config_path = state.get("agent_config_path")
    try:
        agent_config, _ = load_agent_config(config_path)
        hard_limit = agent_config.max_iterations * 2  # allow some extra
    except Exception as exc:
        # Best-effort: a broken config must not crash the graph. Record the
        # failure instead of silently swallowing it (this also surfaces bugs).
        hard_limit = 20
        logger.warning(
            "hard_limit_config_unavailable",
            config_path=config_path,
            error=repr(exc),
            fallback_limit=hard_limit,
        )

    if iteration >= hard_limit:
        logger.info("hard_limit_reached", iteration=iteration, limit=hard_limit)
        return "end"

    return "analyze"
48
+
49
+
50
def build_graph(
    model_type: str = "rnn",
    config_path: str | Path | None = None,
) -> Any:
    """Build and compile the LangGraph tuning-agent graph.

    The graph has four nodes in a loop:
      1. analyze  — Pydantic AI analysis agent examines metrics
      2. decide   — Pydantic AI strategy agent chooses next action
      3. tune     — Ray Tune + Optuna executes the search
      4. evaluate — Check convergence, update best results

    After ``evaluate``, ``should_continue`` either loops back to ``analyze``
    or routes to ``END``.

    Args:
        model_type: Model to tune ('rnn', 'encoder', 'decoder'). Currently
            unused here; ``create_initial_state`` stores it in the graph state.
        config_path: Path to agent config YAML. Currently unused here;
            ``create_initial_state`` stores it in the state as
            ``agent_config_path``.

    Returns:
        The compiled graph returned by ``StateGraph.compile()``, ready for
        ``invoke``/``ainvoke``. This is not a ``StateGraph`` instance, hence
        the ``Any`` annotation.
    """
    graph = StateGraph(AgentState)

    # Nodes: two LLM reasoning steps, the Ray Tune search, then evaluation.
    graph.add_node("analyze", analyze)
    graph.add_node("decide", decide)
    graph.add_node("tune", tune)
    graph.add_node("evaluate", evaluate)

    # Straight-line part of each iteration.
    graph.set_entry_point("analyze")
    graph.add_edge("analyze", "decide")
    graph.add_edge("decide", "tune")
    graph.add_edge("tune", "evaluate")

    # After evaluate, should_continue routes back to analyze or to END.
    graph.add_conditional_edges(
        "evaluate",
        should_continue,
        {"analyze": "analyze", "end": END},
    )

    # NOTE(review): compiled without a checkpointer, so
    # AgentConfig.checkpointing is not applied here. Confirm this is intended.
    return graph.compile()
94
+
95
+
96
def create_initial_state(
    model_type: str = "rnn",
    config_path: str | Path | None = None,
) -> dict[str, Any]:
    """Return a fresh state dict for a new tuning run.

    Counters start at zero, the history and best-config containers start
    empty (new objects on every call), ``best_accuracy`` starts at 0.0,
    ``best_loss`` starts at +inf, and the run is marked as not converged.

    Args:
        model_type: Model to tune ('rnn', 'encoder', 'decoder').
        config_path: Path to agent config YAML. It is stored as a string, or
            as ``None`` when falsy (e.g. ``None`` or ``""``).

    Returns:
        Initial state dict for the graph.
    """
    stored_path = None
    if config_path:
        stored_path = str(config_path)

    return dict(
        iteration=0,
        model_type=model_type,
        history=[],
        best_config={},
        best_accuracy=0.0,
        best_loss=float("inf"),
        search_space_overrides={},
        agent_config_path=stored_path,
        converged=False,
    )
120
+
121
+
122
def _result_from_state(state: dict[str, Any]) -> AgentRunResult:
    """Assemble an AgentRunResult from the graph's final state values."""
    return AgentRunResult(
        best_config=state.get("best_config", {}),
        best_accuracy=state.get("best_accuracy", 0.0),
        best_loss=state.get("best_loss", float("inf")),
        iterations_completed=state.get("iteration", 0),
        converged=state.get("converged", False),
    )


async def run_agent_tuning(
    model_type: str = "rnn",
    config_path: str | Path | None = None,
) -> AgentRunResult:
    """Run the complete agent tuning loop and return its best outcome.

    Steps: build the LangGraph graph, create the initial state, run the
    graph to completion with ``ainvoke``, then return the state's
    ``final_result``. If the graph did not set one, a result is derived from
    the final state; that derived result leaves ``history`` at its default
    (empty list).

    Args:
        model_type: Model to tune ('rnn', 'encoder', 'decoder').
        config_path: Path to agent config YAML (uses default if None).

    Returns:
        AgentRunResult with the best configuration found.
    """
    compiled_graph = build_graph(model_type, config_path)
    start_state = create_initial_state(model_type, config_path)

    shown_path = str(config_path) if config_path else "default"
    logger.info("starting_agent_tuning", model_type=model_type, config_path=shown_path)

    end_state = await compiled_graph.ainvoke(start_state)

    outcome = end_state.get("final_result")
    if outcome is None:
        # The graph did not publish an explicit result, so derive one.
        outcome = _result_from_state(end_state)

    logger.info(
        "agent_tuning_complete",
        best_accuracy=outcome.best_accuracy,
        best_loss=outcome.best_loss,
        iterations=outcome.iterations_completed,
        converged=outcome.converged,
    )
    return outcome
@@ -0,0 +1,90 @@
1
+ """Load agent and tuner configuration from YAML.
2
+
3
+ Provides dataclass-backed configuration loading with sensible defaults,
4
+ so the YAML file only needs to override what differs from defaults.
5
+
6
+ TunerConfig and load_search_space live in sentimentizer.tuner (the
7
+ standalone tuning module). This module re-exports them for convenience
8
+ and adds AgentConfig, which is agent-specific.
9
+ """
10
+
11
+ from __future__ import annotations
12
+
13
+ from dataclasses import dataclass, field
14
+ from pathlib import Path
15
+ from typing import Any
16
+
17
+ import yaml
18
+
19
+ from sentimentizer.tuner import TunerConfig
20
+
21
+
22
@dataclass
class CheckpointConfig:
    """LangGraph checkpointing configuration (the ``agent.checkpointing`` YAML mapping)."""

    # Whether checkpointing is requested. NOTE(review): build_graph compiles
    # the graph without a checkpointer; confirm where this flag is honoured.
    enabled: bool = True
    # Checkpoint database file. The .db suffix suggests SQLite — TODO confirm.
    db_path: str = "agent_checkpoints.db"
28
+
29
+
30
@dataclass
class AgentConfig:
    """Configuration for the Pydantic AI tuning agent (the ``agent`` YAML mapping).

    Connects to GLM 5.1 (or any model) via Ollama's OpenAI-compatible API.
    Every field has a default, so the YAML only needs to list overrides.
    """

    # Ollama model tag handed to the OpenAI-compatible model.
    model_name: str = "glm-5.1:cloud"
    # Base URL of Ollama's OpenAI-compatible endpoint (note the /v1 suffix).
    ollama_base_url: str = "http://localhost:11434/v1"
    # Sampling temperature, sent as a model setting to both agents.
    temperature: float = 0.3
    # Response token cap, sent as a model setting to both agents.
    max_tokens: int = 2048
    # Iteration budget; graph.should_continue uses twice this as a hard stop.
    max_iterations: int = 5
    # Per config.yaml: stop if improvement is below this over the last 3 iterations.
    convergence_threshold: float = 0.005
    # Initial search strategy; config.yaml lists bayesian | grid | random.
    initial_search_strategy: str = "bayesian"
    # Nested settings built from the ``checkpointing`` sub-mapping.
    checkpointing: CheckpointConfig = field(default_factory=CheckpointConfig)
    # Per config.yaml: pause before expensive training runs. Not read in the
    # visible code — TODO confirm where it is honoured.
    human_in_the_loop: bool = False
46
+
47
+
48
def _get_default_config_path() -> Path:
    """Locate the ``config.yaml`` that ships alongside this module."""
    return Path(__file__).with_name("config.yaml")
51
+
52
+
53
def _config_dict_to_agent_config(agent_dict: dict[str, Any] | None) -> AgentConfig:
    """Convert the ``agent`` section of the YAML dict into an AgentConfig.

    The caller's dict is never modified. A missing or ``null`` section, or a
    ``checkpointing: null`` entry, falls back to the dataclass defaults.

    Args:
        agent_dict: Mapping of AgentConfig field names to values, with an
            optional nested ``checkpointing`` mapping. ``None`` is treated as
            an empty section.

    Returns:
        The populated AgentConfig.

    Raises:
        TypeError: If the section contains keys that are not AgentConfig or
            CheckpointConfig fields.
    """
    # Shallow copy: popping "checkpointing" must not mutate the parsed YAML
    # owned by the caller.
    section = dict(agent_dict or {})
    checkpoint_cfg = CheckpointConfig(**(section.pop("checkpointing", None) or {}))
    return AgentConfig(checkpointing=checkpoint_cfg, **section)
58
+
59
+
60
def load_agent_config(
    path: str | Path | None = None,
) -> tuple[AgentConfig, TunerConfig]:
    """Load agent and tuner configuration from a YAML file.

    A missing section, a ``null`` section, or an entirely empty file falls
    back to the defaults of AgentConfig and ``TunerConfig()``.

    Args:
        path: Path to YAML config file. If None, uses the default
            config.yaml bundled with the package.

    Returns:
        Tuple of (AgentConfig, TunerConfig) with all settings populated.

    Raises:
        FileNotFoundError: If the config file doesn't exist.
        ValueError: If the file's top level is not a mapping.
    """
    # NOTE(review): CONFIG_PATH_ENV (defined in this module) is not consulted
    # here; confirm callers resolve that environment variable themselves.
    path = _get_default_config_path() if path is None else Path(path)

    if not path.exists():
        raise FileNotFoundError(f"Agent config file not found: {path}")

    # safe_load: the config is plain data and must not construct objects.
    with open(path, encoding="utf-8") as f:
        raw = yaml.safe_load(f)

    # An empty YAML document parses to None; treat it as "all defaults".
    if raw is None:
        raw = {}
    if not isinstance(raw, dict):
        raise ValueError(f"Agent config file must contain a mapping at top level: {path}")

    # Copy the section so the parsed YAML is never mutated during conversion.
    agent_cfg = _config_dict_to_agent_config(dict(raw.get("agent") or {}))
    tuner_cfg = TunerConfig(**(raw.get("tuner") or {}))

    return agent_cfg, tuner_cfg
87
+
88
+
89
# Name of an environment variable intended to override the config path.
# NOTE(review): nothing in this module reads it; load_agent_config falls back
# only to the bundled config.yaml. Confirm a caller (e.g. a CLI) resolves it.
CONFIG_PATH_ENV = "SENTIMENTIZER_AGENT_CONFIG"
@@ -0,0 +1,104 @@
1
+ """Pydantic models for the tuning agent's structured input/output.
2
+
3
+ These models validate the LLM's responses and the tuning results,
4
+ ensuring that only valid configurations are passed to Ray Tune.
5
+ """
6
+
7
from __future__ import annotations

from typing import Literal

from pydantic import BaseModel, Field, model_validator
12
+
13
+
14
class SearchSpaceParam(BaseModel):
    """A single parameter's search space specification.

    Produced by the LLM agent when it decides to narrow/widen
    the search space for a hyperparameter.
    """

    # The class docstring and field definitions feed the JSON schema shown to
    # the LLM, so explanations are kept in comments instead.

    # Sampling distribution. Range types use low/high; "choice" uses values.
    type: Literal["loguniform", "uniform", "choice", "randint"]
    # Lower bound for loguniform/uniform/randint (must be > 0 for loguniform).
    low: float | None = None
    # Upper bound for loguniform/uniform/randint (must be >= low).
    high: float | None = None
    # Candidate values for "choice" (must be non-empty).
    values: list[int] | list[float] | None = None

    @model_validator(mode="after")
    def check_bounds(self) -> SearchSpaceParam:
        """Reject specs that cannot be turned into a sampler.

        A ``ValueError`` raised here surfaces as a pydantic
        ``ValidationError``. Malformed LLM output is therefore rejected
        instead of reaching Ray Tune.
        """
        if self.type == "choice":
            if not self.values:
                raise ValueError("'choice' parameters require a non-empty 'values' list")
            return self
        if self.low is None or self.high is None:
            raise ValueError(f"'{self.type}' parameters require both 'low' and 'high'")
        if self.low > self.high:
            raise ValueError(f"'low' ({self.low}) must not exceed 'high' ({self.high})")
        if self.type == "loguniform" and self.low <= 0:
            raise ValueError(f"'loguniform' requires 'low' > 0, got {self.low}")
        return self
25
+
26
+
27
class TuningDecision(BaseModel):
    """Structured output from the strategy agent.

    The LLM produces this after analyzing training metrics.
    Pydantic validates it, rejecting hallucinated or invalid configs.
    """

    # The class docstring and Field descriptions go into the JSON schema that
    # pydantic-ai uses to describe this output to the LLM; edit deliberately.

    # Required free-text justification from the LLM.
    reasoning: str = Field(
        ...,
        description="Brief explanation of why this strategy was chosen",
    )
    # Required next action. "stop" presumably ends the tuning loop — confirm
    # in the graph nodes.
    strategy: Literal["widen", "narrow", "change_focus", "increase_epochs", "stop"] = Field(
        ...,
        description="The tuning strategy to apply next",
    )
    # Required per-parameter search space, keyed by hyperparameter name.
    search_space: dict[str, SearchSpaceParam] = Field(
        ...,
        description="Updated search space parameters for Ray Tune",
    )
    # Trial budget; pydantic enforces 5 <= num_samples <= 100 (default 20).
    num_samples: int = Field(
        default=20,
        ge=5,
        le=100,
        description="Number of trials for Ray Tune to run",
    )
+ )
52
+
53
+
54
class TuningResult(BaseModel):
    """Result from a single Ray Tune tuning run.

    Fed back to the LLM agent for analysis in the next iteration.
    """

    # Required headline metrics of the run.
    best_accuracy: float = Field(..., description="Best validation accuracy achieved")
    best_loss: float = Field(..., description="Best validation loss achieved")
    # Only numeric hyperparameter values validate here; e.g. a string-valued
    # categorical choice would be rejected.
    best_config: dict[str, float | int] = Field(
        ..., description="Best hyperparameter configuration found"
    )
    trial_count: int = Field(..., description="Number of trials completed")
    # Defaults to 0.0 when the producer does not compute it.
    improvement_over_last: float = Field(
        default=0.0, description="Accuracy improvement vs. previous best"
    )
69
+
70
+
71
class AnalysisResult(BaseModel):
    """Structured output from the analysis agent.

    The LLM produces this after examining training metrics and history.
    """

    # The class docstring and Field descriptions go into the JSON schema that
    # pydantic-ai uses to describe this output to the LLM; edit deliberately.

    # Required free-text summary.
    summary: str = Field(..., description="Brief summary of the training results")
    # Diagnostic flags; both default to False when the LLM omits them.
    overfitting: bool = Field(
        default=False, description="Whether the model appears to be overfitting"
    )
    underfitting: bool = Field(
        default=False, description="Whether the model appears to be underfitting"
    )
    # Restricted to the four literals; "unclear" when omitted.
    lr_status: Literal["too_high", "too_low", "appropriate", "unclear"] = Field(
        default="unclear", description="Assessment of the learning rate"
    )
    # Parameter names to prioritise next; empty list when omitted.
    suggested_focus: list[str] = Field(
        default_factory=list,
        description="Parameters to focus on in the next iteration",
    )
+ )
91
+
92
+
93
class AgentRunResult(BaseModel):
    """Final result from the complete agent tuning loop.

    Returned when the agent converges or reaches max iterations.
    """

    # Best hyperparameters found; only numeric values validate.
    best_config: dict[str, float | int]
    best_accuracy: float
    # May be float("inf") when no run reported a loss (the initial state value).
    best_loss: float
    iterations_completed: int
    # True when the loop stopped on convergence rather than a limit.
    converged: bool
    # Per-iteration results; dict entries are validated into TuningResult.
    history: list[TuningResult] = Field(default_factory=list)