sentimentizer 0.99.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sentimentizer/__init__.py +60 -0
- sentimentizer/agent/__init__.py +11 -0
- sentimentizer/agent/agents.py +143 -0
- sentimentizer/agent/config.yaml +116 -0
- sentimentizer/agent/graph.py +173 -0
- sentimentizer/agent/loader.py +90 -0
- sentimentizer/agent/models.py +104 -0
- sentimentizer/agent/nodes.py +297 -0
- sentimentizer/agent/prompts.py +78 -0
- sentimentizer/agent/state.py +49 -0
- sentimentizer/config.py +210 -0
- sentimentizer/data/.gitignore +2 -0
- sentimentizer/data/__init__.py +0 -0
- sentimentizer/data/weights.pth +0 -0
- sentimentizer/data/yelp.dictionary +0 -0
- sentimentizer/extractor.py +119 -0
- sentimentizer/loader.py +60 -0
- sentimentizer/models/__init__.py +0 -0
- sentimentizer/models/decoder.py +248 -0
- sentimentizer/models/encoder.py +242 -0
- sentimentizer/models/rnn.py +220 -0
- sentimentizer/serve.py +351 -0
- sentimentizer/tokenizer.py +192 -0
- sentimentizer/trainer.py +583 -0
- sentimentizer/tuner.py +507 -0
- sentimentizer-0.99.0.dist-info/METADATA +489 -0
- sentimentizer-0.99.0.dist-info/RECORD +29 -0
- sentimentizer-0.99.0.dist-info/WHEEL +4 -0
- sentimentizer-0.99.0.dist-info/licenses/LICENSE +20 -0
|
@@ -0,0 +1,60 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
import sys
|
|
3
|
+
import time
|
|
4
|
+
from collections.abc import Callable
|
|
5
|
+
from functools import wraps
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
from typing import Any, TextIO
|
|
8
|
+
|
|
9
|
+
import psutil
|
|
10
|
+
import structlog
|
|
11
|
+
|
|
12
|
+
# Path of this package's ``__init__`` module, and the directory two levels
# above it (i.e. the directory containing the ``sentimentizer`` package),
# made absolute.
file_path: Path = Path(__file__)
root: Path = file_path.parent.parent.absolute()
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def new_logger(level: int = 20, output: TextIO = sys.stderr) -> Any:
    """Configure structlog process-wide and return a logger using that setup.

    Every event is rendered as one JSON object. The object carries merged
    context variables, the log level, formatted exception info and a UTC
    ISO-8601 timestamp. Events below ``level`` are filtered out by the bound
    logger class.

    Note that ``structlog.configure`` is global, so each call replaces the
    configuration installed by any earlier call.

    Args:
        level: Minimum numeric log level (20 is ``logging.INFO``).
        output: Stream that the JSON lines are printed to.

    Returns:
        The structlog logger. It is typed ``Any`` because structlog's bound
        logger accepts arbitrary keyword arguments for event key-value pairs,
        which static type checkers cannot express.
    """
    pipeline = [
        structlog.contextvars.merge_contextvars,
        structlog.processors.add_log_level,
        structlog.processors.format_exc_info,
        structlog.processors.TimeStamper(fmt="iso", utc=True),
        structlog.processors.JSONRenderer(),
    ]
    structlog.configure(
        processors=pipeline,
        wrapper_class=structlog.make_filtering_bound_logger(level),
        logger_factory=structlog.PrintLoggerFactory(file=output),
        cache_logger_on_first_use=True,
    )
    return structlog.getLogger(__name__)
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
# Package-level default logger; emits events at INFO and above.
logger: Any = new_logger(level=logging.INFO)
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def time_decorator(func: Callable[..., Any]) -> Callable[..., Any]:
    """Wrap ``func`` so each successful call logs its runtime and memory stats.

    The wrapper times the call with ``time.perf_counter``. After ``func``
    returns, it logs:

    * the elapsed time;
    * available, free and used system memory in GiB, all taken from a
      single ``psutil.virtual_memory()`` snapshot so the figures are
      mutually consistent.

    If ``func`` raises, nothing is logged and the exception propagates
    unchanged.

    Args:
        func: The callable to instrument.

    Returns:
        A wrapper carrying ``func``'s metadata (via ``functools.wraps``) that
        returns whatever ``func`` returns.
    """
    bytes_per_gib = 1024**3
    # Callables such as functools.partial have no __name__; fall back to repr.
    func_name = getattr(func, "__name__", repr(func))

    @wraps(func)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        start = time.perf_counter()
        result = func(*args, **kwargs)
        elapsed = time.perf_counter() - start
        # One snapshot instead of three separate psutil queries.
        mem = psutil.virtual_memory()
        logger.info(
            "function completed successfully",
            function=func_name,
            run_time=f"{elapsed: 2.4f} seconds",
            available_memory=f"{mem.available / bytes_per_gib: .2f} GBs",
            free_memory=f"{mem.free / bytes_per_gib: .2f} GBs",
            used_memory=f"{mem.used / bytes_per_gib: .2f} GBs",
        )
        return result

    return wrapper
|
|
@@ -0,0 +1,11 @@
|
|
|
1
|
+
"""Sentimentizer hyperparameter tuning agent.
|
|
2
|
+
|
|
3
|
+
Uses Pydantic AI Slim (GLM 5.1 via Ollama) for LLM reasoning,
|
|
4
|
+
LangGraph for workflow orchestration, and Ray Tune + Optuna for
|
|
5
|
+
hyperparameter search.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from sentimentizer.agent.graph import run_agent_tuning
|
|
9
|
+
from sentimentizer.agent.loader import AgentConfig, TunerConfig, load_agent_config
|
|
10
|
+
|
|
11
|
+
# Public API of the agent subpackage (names imported from its submodules).
__all__ = [
    "AgentConfig",
    "TunerConfig",
    "load_agent_config",
    "run_agent_tuning",
]
|
|
@@ -0,0 +1,143 @@
|
|
|
1
|
+
"""Pydantic AI agent definitions for the tuning agent.
|
|
2
|
+
|
|
3
|
+
Two agents work together:
|
|
4
|
+
1. AnalysisAgent — examines training metrics and diagnoses issues
|
|
5
|
+
2. StrategyAgent — decides the next tuning strategy and search space
|
|
6
|
+
|
|
7
|
+
Both use GLM 5.1 by default (configurable via AgentConfig.model_name) through Ollama's OpenAI-compatible API.
|
|
8
|
+
Pydantic AI validates the LLM's structured output, rejecting
|
|
9
|
+
hallucinated or invalid configurations.
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
from __future__ import annotations
|
|
13
|
+
|
|
14
|
+
from dataclasses import dataclass
|
|
15
|
+
from typing import Any
|
|
16
|
+
|
|
17
|
+
from pydantic_ai import Agent, RunContext
|
|
18
|
+
from pydantic_ai.models.openai import OpenAIModel
|
|
19
|
+
from pydantic_ai.providers.openai import OpenAIProvider
|
|
20
|
+
|
|
21
|
+
from sentimentizer import new_logger
|
|
22
|
+
from sentimentizer.agent.loader import AgentConfig
|
|
23
|
+
from sentimentizer.agent.models import AnalysisResult, TuningDecision
|
|
24
|
+
from sentimentizer.agent.prompts import ANALYSIS_SYSTEM_PROMPT, STRATEGY_SYSTEM_PROMPT
|
|
25
|
+
from sentimentizer.config import DEFAULT_LOG_LEVEL
|
|
26
|
+
|
|
27
|
+
# Module logger; its level comes from the shared DEFAULT_LOG_LEVEL setting.
logger = new_logger(level=DEFAULT_LOG_LEVEL)
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
@dataclass
class TuningDeps:
    """Context shared with both tuning agents through dependency injection.

    The agents' tools read these fields via ``RunContext.deps``. Every field
    has a default, so ``TuningDeps()`` is a valid, empty context.
    """

    # Model family under tuning: "rnn", "encoder" or "decoder".
    model_type: str = "rnn"
    # Records of earlier tuning iterations; None when there are none yet.
    history: list[dict[str, Any]] | None = None
    # Metrics for the iteration currently being examined.
    current_metrics: dict[str, Any] | None = None
    # Default search-space spec per parameter name (from the YAML config).
    search_space_defaults: dict[str, dict[str, Any]] | None = None
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
def _create_model(config: AgentConfig) -> OpenAIModel:
    """Build the chat model shared by the tuning agents.

    Ollama serves an OpenAI-compatible API under ``/v1``. The stock
    pydantic-ai OpenAI model is therefore reused, with its provider pointed
    at ``config.ollama_base_url``.
    """
    return OpenAIModel(
        model_name=config.model_name,
        provider=OpenAIProvider(
            base_url=config.ollama_base_url,
            api_key="ollama",  # placeholder: Ollama doesn't require a real key
        ),
    )
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def create_analysis_agent(config: AgentConfig) -> Agent[TuningDeps, AnalysisResult]:
    """Build the agent that diagnoses training metrics.

    The agent is prompted with ``ANALYSIS_SYSTEM_PROMPT``. Its tools expose
    the tuning history, the current metrics and the model type. It must
    answer with an ``AnalysisResult``, which Pydantic validates. The result
    covers overfitting and underfitting flags, a learning-rate assessment and
    suggested parameters to focus on.

    Args:
        config: LLM connection and sampling settings.

    Returns:
        A configured ``Agent`` with three tools registered.
    """
    settings = {"temperature": config.temperature, "max_tokens": config.max_tokens}
    agent = Agent(
        model=_create_model(config),
        output_type=AnalysisResult,
        deps_type=TuningDeps,
        system_prompt=ANALYSIS_SYSTEM_PROMPT,
        model_settings=settings,
    )

    # Each tool's function name and docstring form the tool definition sent to
    # the LLM, so both are kept verbatim.

    @agent.tool
    def get_previous_results(ctx: RunContext[TuningDeps]) -> list[dict[str, Any]]:
        """Get the history of past tuning iterations."""
        past = ctx.deps.history
        return past if past else []

    @agent.tool
    def get_current_metrics(ctx: RunContext[TuningDeps]) -> dict[str, Any]:
        """Get the current iteration's training metrics."""
        metrics = ctx.deps.current_metrics
        return metrics if metrics else {}

    @agent.tool
    def get_model_type(ctx: RunContext[TuningDeps]) -> str:
        """Get the model type being tuned (rnn, encoder, decoder)."""
        return ctx.deps.model_type

    return agent
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
def create_strategy_agent(config: AgentConfig) -> Agent[TuningDeps, TuningDecision]:
    """Build the agent that chooses the next tuning action.

    The agent is prompted with ``STRATEGY_SYSTEM_PROMPT``. Its tools expose
    the analysis/metrics, tuning history, default search space and model
    type. It must answer with a ``TuningDecision`` (strategy, search space,
    trial count), which Pydantic validates before it is used.

    Args:
        config: LLM connection and sampling settings.

    Returns:
        A configured ``Agent`` with four tools registered.
    """
    settings = {"temperature": config.temperature, "max_tokens": config.max_tokens}
    agent = Agent(
        model=_create_model(config),
        output_type=TuningDecision,
        deps_type=TuningDeps,
        system_prompt=STRATEGY_SYSTEM_PROMPT,
        model_settings=settings,
    )

    # Each tool's function name and docstring form the tool definition sent to
    # the LLM, so both are kept verbatim.

    @agent.tool
    def get_analysis(ctx: RunContext[TuningDeps]) -> dict[str, Any]:
        """Get the current analysis results from the analysis agent."""
        # NOTE(review): this returns deps.current_metrics; presumably the caller
        # stores the analysis output in that field — confirm against nodes.py.
        metrics = ctx.deps.current_metrics
        return metrics if metrics else {}

    @agent.tool
    def get_previous_results(ctx: RunContext[TuningDeps]) -> list[dict[str, Any]]:
        """Get the history of past tuning iterations."""
        past = ctx.deps.history
        return past if past else []

    @agent.tool
    def get_default_search_space(ctx: RunContext[TuningDeps]) -> dict[str, dict[str, Any]]:
        """Get the default search space parameters from the YAML config."""
        defaults = ctx.deps.search_space_defaults
        return defaults if defaults else {}

    @agent.tool
    def get_model_type(ctx: RunContext[TuningDeps]) -> str:
        """Get the model type being tuned (rnn, encoder, decoder)."""
        return ctx.deps.model_type

    return agent
|
|
@@ -0,0 +1,116 @@
|
|
|
1
|
+
# Sentimentizer Agent Configuration
|
|
2
|
+
# Controls the hyperparameter tuning agent powered by GLM 5.1 via Ollama
|
|
3
|
+
|
|
4
|
+
agent:
|
|
5
|
+
# LLM configuration (Ollama OpenAI-compatible endpoint)
|
|
6
|
+
model_name: "glm-5.1:cloud"
|
|
7
|
+
ollama_base_url: "http://localhost:11434/v1"
|
|
8
|
+
temperature: 0.3
|
|
9
|
+
max_tokens: 2048
|
|
10
|
+
|
|
11
|
+
# Agent loop configuration
|
|
12
|
+
max_iterations: 5
|
|
13
|
+
convergence_threshold: 0.005 # Stop if improvement < this over last 3 iterations
|
|
14
|
+
initial_search_strategy: "bayesian" # bayesian | grid | random
|
|
15
|
+
|
|
16
|
+
# LangGraph checkpointing (resume interrupted tuning sessions)
|
|
17
|
+
checkpointing:
|
|
18
|
+
enabled: true
|
|
19
|
+
db_path: "agent_checkpoints.db"
|
|
20
|
+
|
|
21
|
+
# Human-in-the-loop (pause before expensive training runs)
|
|
22
|
+
human_in_the_loop: false
|
|
23
|
+
|
|
24
|
+
tuner:
|
|
25
|
+
# Ray Tune + Optuna scheduler configuration
|
|
26
|
+
scheduler: "asha" # asha | hyperband | median
|
|
27
|
+
metric: "val_accuracy"
|
|
28
|
+
mode: "max"
|
|
29
|
+
num_samples: 20
|
|
30
|
+
grace_period: 2
|
|
31
|
+
reduction_factor: 3
|
|
32
|
+
|
|
33
|
+
# Search spaces per model type
|
|
34
|
+
# Each parameter supports: loguniform, uniform, choice, randint
|
|
35
|
+
search_spaces:
|
|
36
|
+
rnn:
|
|
37
|
+
lr:
|
|
38
|
+
type: "loguniform"
|
|
39
|
+
low: 1.0e-5
|
|
40
|
+
high: 1.0e-2
|
|
41
|
+
hidden_size:
|
|
42
|
+
type: "choice"
|
|
43
|
+
values: [128, 256, 512]
|
|
44
|
+
num_layers:
|
|
45
|
+
type: "randint"
|
|
46
|
+
low: 1
|
|
47
|
+
high: 4
|
|
48
|
+
dropout:
|
|
49
|
+
type: "uniform"
|
|
50
|
+
low: 0.1
|
|
51
|
+
high: 0.5
|
|
52
|
+
weight_decay:
|
|
53
|
+
type: "loguniform"
|
|
54
|
+
low: 1.0e-6
|
|
55
|
+
high: 1.0e-2
|
|
56
|
+
batch_size:
|
|
57
|
+
type: "choice"
|
|
58
|
+
values: [32, 64, 128]
|
|
59
|
+
|
|
60
|
+
encoder:
|
|
61
|
+
lr:
|
|
62
|
+
type: "loguniform"
|
|
63
|
+
low: 1.0e-5
|
|
64
|
+
high: 5.0e-3
|
|
65
|
+
d_model:
|
|
66
|
+
type: "choice"
|
|
67
|
+
values: [128, 256, 512]
|
|
68
|
+
n_heads:
|
|
69
|
+
type: "choice"
|
|
70
|
+
values: [2, 4, 8]
|
|
71
|
+
n_layers:
|
|
72
|
+
type: "randint"
|
|
73
|
+
low: 1
|
|
74
|
+
high: 6
|
|
75
|
+
dropout:
|
|
76
|
+
type: "uniform"
|
|
77
|
+
low: 0.1
|
|
78
|
+
high: 0.5
|
|
79
|
+
weight_decay:
|
|
80
|
+
type: "loguniform"
|
|
81
|
+
low: 1.0e-6
|
|
82
|
+
high: 1.0e-2
|
|
83
|
+
batch_size:
|
|
84
|
+
type: "choice"
|
|
85
|
+
values: [32, 64, 128]
|
|
86
|
+
|
|
87
|
+
decoder:
|
|
88
|
+
lr:
|
|
89
|
+
type: "loguniform"
|
|
90
|
+
low: 1.0e-5
|
|
91
|
+
high: 1.0e-2
|
|
92
|
+
d_model:
|
|
93
|
+
type: "choice"
|
|
94
|
+
values: [128, 256, 512]
|
|
95
|
+
n_heads:
|
|
96
|
+
type: "choice"
|
|
97
|
+
values: [2, 4, 8]
|
|
98
|
+
n_encoder_layers:
|
|
99
|
+
type: "randint"
|
|
100
|
+
low: 1
|
|
101
|
+
high: 4
|
|
102
|
+
n_decoder_layers:
|
|
103
|
+
type: "randint"
|
|
104
|
+
low: 2
|
|
105
|
+
high: 6
|
|
106
|
+
dropout:
|
|
107
|
+
type: "uniform"
|
|
108
|
+
low: 0.1
|
|
109
|
+
high: 0.5
|
|
110
|
+
weight_decay:
|
|
111
|
+
type: "loguniform"
|
|
112
|
+
low: 1.0e-6
|
|
113
|
+
high: 1.0e-2
|
|
114
|
+
batch_size:
|
|
115
|
+
type: "choice"
|
|
116
|
+
values: [32, 64, 128]
|
|
@@ -0,0 +1,173 @@
|
|
|
1
|
+
"""LangGraph state graph for the tuning agent.
|
|
2
|
+
|
|
3
|
+
Defines the workflow: analyze → decide → tune → evaluate → (loop or end)
|
|
4
|
+
|
|
5
|
+
Uses LangGraph for orchestration and checkpointing, with Pydantic AI
|
|
6
|
+
agents (GLM 5.1 via Ollama) as the LLM reasoning layer inside nodes.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from __future__ import annotations
|
|
10
|
+
|
|
11
|
+
from pathlib import Path
|
|
12
|
+
from typing import Any
|
|
13
|
+
|
|
14
|
+
from langgraph.graph import END, StateGraph
|
|
15
|
+
|
|
16
|
+
from sentimentizer import new_logger
|
|
17
|
+
from sentimentizer.agent.loader import load_agent_config
|
|
18
|
+
from sentimentizer.agent.models import AgentRunResult
|
|
19
|
+
from sentimentizer.agent.nodes import analyze, decide, evaluate, tune
|
|
20
|
+
from sentimentizer.agent.state import AgentState
|
|
21
|
+
from sentimentizer.config import DEFAULT_LOG_LEVEL
|
|
22
|
+
|
|
23
|
+
# Module logger; its level comes from the shared DEFAULT_LOG_LEVEL setting.
logger = new_logger(level=DEFAULT_LOG_LEVEL)
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def should_continue(state: AgentState) -> str:
    """Conditional edge after ``evaluate``: loop back or stop.

    The loop stops when ``state["converged"]`` is truthy, or when the
    iteration count reaches a hard safety limit. The limit is twice the
    configured ``max_iterations``, or 20 if the agent config cannot be loaded.

    Args:
        state: Current graph state. Reads ``converged``, ``iteration`` and
            ``agent_config_path``.

    Returns:
        ``"analyze"`` to run another iteration, or ``"end"`` to stop.
    """
    if state.get("converged", False):
        return "end"

    # Safety check: if iteration exceeds a hard limit, stop
    iteration = state.get("iteration", 0)
    config_path = state.get("agent_config_path")
    try:
        agent_config, _ = load_agent_config(config_path)
        hard_limit = agent_config.max_iterations * 2  # allow some extra
    except Exception as exc:
        # Best-effort: a broken config must not crash the running graph, but
        # the silent change of the safety limit should be visible in the logs.
        hard_limit = 20
        logger.warning(
            "agent_config_unavailable_using_fallback_limit",
            config_path=config_path,
            error=repr(exc),
            limit=hard_limit,
        )

    if iteration >= hard_limit:
        logger.info("hard_limit_reached", iteration=iteration, limit=hard_limit)
        return "end"

    return "analyze"
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def build_graph(
    model_type: str = "rnn",
    config_path: str | Path | None = None,
) -> Any:
    """Build and compile the LangGraph workflow for the tuning agent.

    The graph has four nodes in a loop:
      1. analyze  — Pydantic AI analysis agent examines metrics
      2. decide   — Pydantic AI strategy agent chooses next action
      3. tune     — Ray Tune + Optuna executes the search
      4. evaluate — Check convergence, update best results

    After ``evaluate``, the ``should_continue`` conditional edge either loops
    back to ``analyze`` or goes to END.

    Args:
        model_type: Model to tune ('rnn', 'encoder', 'decoder'). Not used when
            building the graph; the model type is carried in the graph state
            instead (see ``create_initial_state``). Kept for API compatibility.
        config_path: Path to agent config YAML. Likewise not used here; it
            travels in the ``agent_config_path`` state field.

    Returns:
        The compiled graph returned by ``StateGraph.compile()`` (not a
        ``StateGraph`` itself), ready for ``invoke``/``ainvoke``. It is
        compiled without a checkpointer, so the ``checkpointing`` settings of
        the agent config are not applied here.
    """
    graph = StateGraph(AgentState)

    # Add nodes
    graph.add_node("analyze", analyze)
    graph.add_node("decide", decide)
    graph.add_node("tune", tune)
    graph.add_node("evaluate", evaluate)

    # Define edges
    graph.set_entry_point("analyze")
    graph.add_edge("analyze", "decide")
    graph.add_edge("decide", "tune")
    graph.add_edge("tune", "evaluate")

    # Conditional edge after evaluate: loop back or end
    graph.add_conditional_edges(
        "evaluate",
        should_continue,
        {"analyze": "analyze", "end": END},
    )

    return graph.compile()
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
def create_initial_state(
    model_type: str = "rnn",
    config_path: str | Path | None = None,
) -> dict[str, Any]:
    """Build the starting state for a fresh tuning run.

    Args:
        model_type: Model to tune ('rnn', 'encoder', 'decoder').
        config_path: Path to agent config YAML. It is stored as a string, and
            any falsy value becomes ``None``.

    Returns:
        A new state dict with a zero iteration count, empty history,
        best-config and override containers, and "nothing found yet" best
        metrics.
    """
    stored_path = str(config_path) if config_path else None
    state: dict[str, Any] = {
        "iteration": 0,
        "model_type": model_type,
        "history": [],
        "best_config": {},
        # Worst possible starting values, so any real result improves on them.
        "best_accuracy": 0.0,
        "best_loss": float("inf"),
        "search_space_overrides": {},
        "agent_config_path": stored_path,
        "converged": False,
    }
    return state
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
async def run_agent_tuning(
    model_type: str = "rnn",
    config_path: str | Path | None = None,
) -> AgentRunResult:
    """Run the whole tuning loop and return the best configuration found.

    This is the main entry point of the tuning agent. It compiles the
    LangGraph workflow, seeds it with a fresh initial state, awaits it to
    completion and turns the final state into an ``AgentRunResult``.

    If the graph left a ``final_result`` in the state, that object is
    returned unchanged. Otherwise a result is assembled from the ``best_*``,
    ``iteration`` and ``converged`` state fields; its ``history`` keeps the
    model default in that case.

    Args:
        model_type: Model to tune ('rnn', 'encoder', 'decoder').
        config_path: Path to agent config YAML (uses default if None).

    Returns:
        AgentRunResult with the best configuration found.
    """
    compiled = build_graph(model_type, config_path)
    start_state = create_initial_state(model_type, config_path)

    config_label = str(config_path) if config_path else "default"
    logger.info(
        "starting_agent_tuning",
        model_type=model_type,
        config_path=config_label,
    )

    end_state = await compiled.ainvoke(start_state)

    result = end_state.get("final_result")
    if result is None:
        # The graph did not set a result explicitly; derive one from state.
        result = AgentRunResult(
            best_config=end_state.get("best_config", {}),
            best_accuracy=end_state.get("best_accuracy", 0.0),
            best_loss=end_state.get("best_loss", float("inf")),
            iterations_completed=end_state.get("iteration", 0),
            converged=end_state.get("converged", False),
        )

    logger.info(
        "agent_tuning_complete",
        best_accuracy=result.best_accuracy,
        best_loss=result.best_loss,
        iterations=result.iterations_completed,
        converged=result.converged,
    )
    return result
|
|
@@ -0,0 +1,90 @@
|
|
|
1
|
+
"""Load agent and tuner configuration from YAML.
|
|
2
|
+
|
|
3
|
+
Provides dataclass-backed configuration loading with sensible defaults,
|
|
4
|
+
so the YAML file only needs to override what differs from defaults.
|
|
5
|
+
|
|
6
|
+
TunerConfig (like load_search_space) lives in sentimentizer.tuner, the
standalone tuning module. This module re-exports only TunerConfig for
convenience and adds AgentConfig, which is agent-specific.
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
from __future__ import annotations
|
|
12
|
+
|
|
13
|
+
from dataclasses import dataclass, field
|
|
14
|
+
from pathlib import Path
|
|
15
|
+
from typing import Any
|
|
16
|
+
|
|
17
|
+
import yaml
|
|
18
|
+
|
|
19
|
+
from sentimentizer.tuner import TunerConfig
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
@dataclass
|
|
23
|
+
class CheckpointConfig:
|
|
24
|
+
"""LangGraph checkpointing configuration."""
|
|
25
|
+
|
|
26
|
+
enabled: bool = True
|
|
27
|
+
db_path: str = "agent_checkpoints.db"
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
@dataclass
|
|
31
|
+
class AgentConfig:
|
|
32
|
+
"""Configuration for the Pydantic AI tuning agent.
|
|
33
|
+
|
|
34
|
+
Connects to GLM 5.1 (or any model) via Ollama's OpenAI-compatible API.
|
|
35
|
+
"""
|
|
36
|
+
|
|
37
|
+
model_name: str = "glm-5.1:cloud"
|
|
38
|
+
ollama_base_url: str = "http://localhost:11434/v1"
|
|
39
|
+
temperature: float = 0.3
|
|
40
|
+
max_tokens: int = 2048
|
|
41
|
+
max_iterations: int = 5
|
|
42
|
+
convergence_threshold: float = 0.005
|
|
43
|
+
initial_search_strategy: str = "bayesian"
|
|
44
|
+
checkpointing: CheckpointConfig = field(default_factory=CheckpointConfig)
|
|
45
|
+
human_in_the_loop: bool = False
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def _get_default_config_path() -> Path:
|
|
49
|
+
"""Return path to the default config.yaml bundled with the package."""
|
|
50
|
+
return Path(__file__).parent / "config.yaml"
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def _config_dict_to_agent_config(agent_dict: dict[str, Any]) -> AgentConfig:
|
|
54
|
+
"""Convert the agent section of the YAML dict into an AgentConfig."""
|
|
55
|
+
cp_dict = agent_dict.pop("checkpointing", {})
|
|
56
|
+
checkpoint_cfg = CheckpointConfig(**cp_dict)
|
|
57
|
+
return AgentConfig(checkpointing=checkpoint_cfg, **agent_dict)
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
# Environment variable that, when set to a non-empty path, overrides the
# bundled default config used by load_agent_config(path=None).
CONFIG_PATH_ENV = "SENTIMENTIZER_AGENT_CONFIG"


def load_agent_config(
    path: str | Path | None = None,
) -> tuple[AgentConfig, TunerConfig]:
    """Load agent and tuner configuration from a YAML file.

    Args:
        path: Path to YAML config file. If None, the path named by the
            ``SENTIMENTIZER_AGENT_CONFIG`` environment variable is used when
            it is set and non-empty. Otherwise the config.yaml bundled with
            the package is used.

    Returns:
        Tuple of (AgentConfig, TunerConfig). An empty file or a missing or
        empty ``agent``/``tuner`` section is treated as an empty mapping.

    Raises:
        FileNotFoundError: If the config file doesn't exist.
        ValueError: If the YAML document is not a mapping.
    """
    import os  # only needed for the environment-variable override

    if path is None:
        env_path = os.environ.get(CONFIG_PATH_ENV)
        path = Path(env_path) if env_path else _get_default_config_path()
    else:
        path = Path(path)

    if not path.exists():
        raise FileNotFoundError(f"Agent config file not found: {path}")

    with path.open(encoding="utf-8") as f:
        # safe_load never constructs arbitrary Python objects from the file;
        # an empty document yields None, which is treated as {}.
        raw = yaml.safe_load(f) or {}

    if not isinstance(raw, dict):
        raise ValueError(
            f"Agent config must be a YAML mapping, got {type(raw).__name__}: {path}"
        )

    agent_cfg = _config_dict_to_agent_config(raw.get("agent") or {})
    tuner_cfg = TunerConfig(**(raw.get("tuner") or {}))

    return agent_cfg, tuner_cfg
|
|
@@ -0,0 +1,104 @@
|
|
|
1
|
+
"""Pydantic models for the tuning agent's structured input/output.
|
|
2
|
+
|
|
3
|
+
These models validate the LLM's responses and the tuning results,
|
|
4
|
+
ensuring that only valid configurations are passed to Ray Tune.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations

from typing import Literal

from pydantic import BaseModel, Field, model_validator
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class SearchSpaceParam(BaseModel):
    """A single parameter's search space specification.

    Produced by the LLM agent when it decides to narrow/widen
    the search space for a hyperparameter.
    """

    # NOTE: the class docstring above is part of the JSON schema shown to the
    # LLM, so it is kept verbatim; validation rules are documented here.
    #
    # Which fields are meaningful depends on ``type``:
    #   * choice                      -> ``values`` (non-empty)
    #   * loguniform/uniform/randint  -> ``low`` < ``high``
    #   * loguniform additionally     -> ``low`` > 0
    type: Literal["loguniform", "uniform", "choice", "randint"]
    low: float | None = None
    high: float | None = None
    values: list[int] | list[float] | None = None

    @model_validator(mode="after")
    def check_range(self) -> SearchSpaceParam:
        """Reject specs that cannot be sampled.

        Raises:
            ValueError: If required fields are missing or the range is empty
                or invalid. Pydantic reports this as a ValidationError, the
                same as any other field error.
        """
        if self.type == "choice":
            if not self.values:
                raise ValueError("'choice' parameters require a non-empty 'values' list")
            return self
        if self.low is None or self.high is None:
            raise ValueError(f"'{self.type}' parameters require both 'low' and 'high'")
        if self.low >= self.high:
            raise ValueError(
                f"'low' ({self.low}) must be less than 'high' ({self.high}) for '{self.type}'"
            )
        if self.type == "loguniform" and self.low <= 0:
            raise ValueError(f"'loguniform' requires 'low' > 0, got {self.low}")
        return self
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class TuningDecision(BaseModel):
    """Structured output from the strategy agent.

    The LLM produces this after analyzing training metrics.
    Pydantic validates it, rejecting hallucinated or invalid configs.
    """

    # The docstring and Field descriptions feed the JSON schema given to the
    # LLM, so their text is intentionally unchanged.

    # Required: short justification for the chosen strategy.
    reasoning: str = Field(
        description="Brief explanation of why this strategy was chosen",
    )
    # Required: one of a fixed set of strategy keywords.
    strategy: Literal["widen", "narrow", "change_focus", "increase_epochs", "stop"] = Field(
        description="The tuning strategy to apply next",
    )
    # Required: parameter name -> search spec for the next Ray Tune run.
    search_space: dict[str, SearchSpaceParam] = Field(
        description="Updated search space parameters for Ray Tune",
    )
    # Optional trial budget, constrained to 5..100 inclusive.
    num_samples: int = Field(
        20,
        description="Number of trials for Ray Tune to run",
        ge=5,
        le=100,
    )
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
class TuningResult(BaseModel):
    """Result from a single Ray Tune tuning run.

    Fed back to the LLM agent for analysis in the next iteration.
    """

    # Required headline metrics of the run.
    best_accuracy: float = Field(description="Best validation accuracy achieved")
    best_loss: float = Field(description="Best validation loss achieved")
    # Required: hyperparameter name -> value of the best trial.
    best_config: dict[str, float | int] = Field(
        description="Best hyperparameter configuration found",
    )
    trial_count: int = Field(description="Number of trials completed")
    # Optional; stays 0.0 unless the producer fills it in.
    improvement_over_last: float = Field(
        0.0,
        description="Accuracy improvement vs. previous best",
    )
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
class AnalysisResult(BaseModel):
    """Structured output from the analysis agent.

    The LLM produces this after examining training metrics and history.
    """

    # The docstring and Field descriptions feed the JSON schema given to the
    # LLM, so their text is intentionally unchanged.

    # Required free-text summary.
    summary: str = Field(description="Brief summary of the training results")
    # Diagnostic flags; both default to "not detected".
    overfitting: bool = Field(
        False, description="Whether the model appears to be overfitting"
    )
    underfitting: bool = Field(
        False, description="Whether the model appears to be underfitting"
    )
    lr_status: Literal["too_high", "too_low", "appropriate", "unclear"] = Field(
        "unclear", description="Assessment of the learning rate"
    )
    # Hyperparameter names; an empty list when the LLM suggests none.
    suggested_focus: list[str] = Field(
        description="Parameters to focus on in the next iteration",
        default_factory=list,
    )
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
class AgentRunResult(BaseModel):
    """Final result from the complete agent tuning loop.

    Returned when the agent converges or reaches max iterations.
    """

    # Best configuration found and its metrics.
    best_config: dict[str, float | int]
    best_accuracy: float
    best_loss: float
    # Loop bookkeeping.
    iterations_completed: int
    converged: bool
    # Per-iteration results; a fresh empty list unless supplied.
    history: list[TuningResult] = Field(default_factory=list)
|