agentensor 0.0.4__tar.gz → 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- agentensor-0.1.0/.bumpversion.cfg +18 -0
- {agentensor-0.0.4 → agentensor-0.1.0}/.gitignore +40 -10
- agentensor-0.1.0/PKG-INFO +40 -0
- {agentensor-0.0.4 → agentensor-0.1.0}/pyproject.toml +21 -20
- agentensor-0.1.0/src/agentensor/__init__.py +8 -0
- {agentensor-0.0.4 → agentensor-0.1.0/src}/agentensor/module.py +4 -1
- agentensor-0.1.0/src/agentensor/optim.py +109 -0
- agentensor-0.1.0/src/agentensor/py.typed +0 -0
- {agentensor-0.0.4 → agentensor-0.1.0/src}/agentensor/tensor.py +14 -8
- agentensor-0.1.0/src/agentensor/train.py +280 -0
- agentensor-0.0.4/.bumpversion.cfg +0 -12
- agentensor-0.0.4/.github/workflows/after-ci.yml +0 -50
- agentensor-0.0.4/.github/workflows/ci.yml +0 -103
- agentensor-0.0.4/.pre-commit-config.yaml +0 -26
- agentensor-0.0.4/.python-version +0 -1
- agentensor-0.0.4/LICENSE +0 -21
- agentensor-0.0.4/Makefile +0 -15
- agentensor-0.0.4/PKG-INFO +0 -45
- agentensor-0.0.4/agentensor/__init__.py +0 -1
- agentensor-0.0.4/agentensor/optim.py +0 -57
- agentensor-0.0.4/agentensor/train.py +0 -93
- agentensor-0.0.4/examples/evaluate.py +0 -154
- agentensor-0.0.4/examples/train.py +0 -118
- agentensor-0.0.4/mkdocs.yml +0 -1
- agentensor-0.0.4/tests/test_loss.py +0 -96
- agentensor-0.0.4/tests/test_module.py +0 -119
- agentensor-0.0.4/tests/test_optim.py +0 -130
- agentensor-0.0.4/tests/test_tensor.py +0 -156
- agentensor-0.0.4/tests/test_train.py +0 -266
- agentensor-0.0.4/uv.lock +0 -3431
- {agentensor-0.0.4 → agentensor-0.1.0}/README.md +0 -0
- {agentensor-0.0.4 → agentensor-0.1.0/src}/agentensor/loss.py +0 -0
agentensor-0.1.0/.bumpversion.cfg (new file)
@@ -0,0 +1,18 @@
+[bumpversion]
+current_version = 0.1.0
+commit = True
+tag = True
+tag_name = agentensor-v{new_version}
+message = Bump agentensor version to {new_version}
+
+[bumpversion:file:pyproject.toml]
+search = version = "{current_version}"
+replace = version = "{new_version}"
+
+[bumpversion:file:../../uv.lock]
+search = [[package]]
+    name = "agentensor"
+    version = "{current_version}"
+replace = [[package]]
+    name = "agentensor"
+    version = "{new_version}"

{agentensor-0.0.4 → agentensor-0.1.0}/.gitignore
@@ -10,11 +10,10 @@ __pycache__/
 .Python
 build/
 develop-eggs/
-dist/
 downloads/
 eggs/
 .eggs/
-lib/
+/lib/
 lib64/
 parts/
 sdist/
@@ -57,7 +56,6 @@ cover/
 *.pot

 # Django stuff:
-*.log
 local_settings.py
 db.sqlite3
 db.sqlite3-journal
@@ -153,12 +151,44 @@ dmypy.json
 # Cython debug symbols
 cython_debug/

-#
-
-
-#
-
-
+# Dependencies
+node_modules/
+
+# Logs
+logs
+*.log
+npm-debug.log*
+yarn-debug.log*
+yarn-error.log*
+pnpm-debug.log*
+lerna-debug.log*

-#
+# Build
+dist/
+dist-ssr/
+*.local
+
+# Editor directories and files
+.vscode/*
+!.vscode/extensions.json
+.idea/
+*.suo
+*.ntvs*
+*.njsproj
+*.sln
+*.sw?
+
+# OS generated files
+.DS_Store
 **/.DS_Store
+
+# LangGraph
+checkpoints.sqlite
+
+# Orcheo
+.orcheo/
+**/workflow_config.json
+
+# Miscellaneous
+homepage/
+orcheo-paper/

agentensor-0.1.0/PKG-INFO (new file)
@@ -0,0 +1,40 @@
+Metadata-Version: 2.4
+Name: agentensor
+Version: 0.1.0
+Summary: Agent prompt tensors, modules, and optimizers for Orcheo workflows
+Requires-Python: >=3.12
+Requires-Dist: datasets>=3.5.0
+Requires-Dist: langchain-ollama>=0.3.3
+Requires-Dist: langchain-openai>=0.3.18
+Requires-Dist: langchain>=1.1.3
+Requires-Dist: langgraph>=1.0.5
+Requires-Dist: logfire>=3.14.0
+Requires-Dist: pydantic-ai>=0.2.4
+Provides-Extra: dev
+Requires-Dist: bump2version; extra == 'dev'
+Requires-Dist: diff-cover>=9.2.4; extra == 'dev'
+Requires-Dist: mypy>=1.11.2; extra == 'dev'
+Requires-Dist: pre-commit; extra == 'dev'
+Requires-Dist: pytest; extra == 'dev'
+Requires-Dist: pytest-asyncio>=0.23.8; extra == 'dev'
+Requires-Dist: pytest-cov>=4.1.0; extra == 'dev'
+Requires-Dist: ruff>=0.11.3; extra == 'dev'
+Requires-Dist: smokeshow>=0.5.0; extra == 'dev'
+Requires-Dist: types-requests; extra == 'dev'
+Provides-Extra: docs
+Requires-Dist: mkdocs; extra == 'docs'
+Requires-Dist: mkdocs-gen-files; extra == 'docs'
+Requires-Dist: mkdocs-jupyter; extra == 'docs'
+Requires-Dist: mkdocs-material; extra == 'docs'
+Requires-Dist: mkdocstrings[python]>=0.28.1; extra == 'docs'
+Description-Content-Type: text/markdown
+
+# AgenTensor
+
+[](https://github.com/ShaojieJiang/agentensor/actions/workflows/ci.yml?query=branch%3Amain)
+[](https://coverage-badge.samuelcolvin.workers.dev/redirect/ShaojieJiang/agentensor)
+[](https://pypi.python.org/pypi/agentensor)
+
+## TODO
+
+- [ ] Add parameter saving

{agentensor-0.0.4 → agentensor-0.1.0}/pyproject.toml
@@ -2,11 +2,27 @@
 build-backend = "hatchling.build"
 requires = ["hatchling"]

-[project.optional-dependencies]
+[project]
+dependencies = [
+  "datasets>=3.5.0",
+  "langchain>=1.1.3",
+  "langchain-ollama>=0.3.3",
+  "langchain-openai>=0.3.18",
+  "langgraph>=1.0.5",
+  "logfire>=3.14.0",
+  "pydantic-ai>=0.2.4"
+]
+description = "Agent prompt tensors, modules, and optimizers for Orcheo workflows"
+name = "agentensor"
+readme = "README.md"
+requires-python = ">=3.12"
+url = "https://github.com/ShaojieJiang/agentensor"
+version = "0.1.0"
+
+[project.optional-dependencies]
 dev = [
   "bump2version",
   "diff-cover>=9.2.4",
-  "isort",
   "mypy>=1.11.2",
   "pre-commit",
   "pytest",
@@ -24,24 +40,6 @@ docs = [
   "mkdocstrings[python]>=0.28.1"
 ]

-[project]
-dependencies = [
-  "datasets>=3.5.0",
-  "langchain>=0.3.25",
-  "langchain-ollama>=0.3.3",
-  "langchain-openai>=0.3.18",
-  "langgraph>=0.4.5",
-  "logfire>=3.14.0",
-  "pydantic-ai>=0.2.4"
-]
-description = "Add your description here"
-license = {file = "LICENSE"}
-name = "agentensor"
-readme = "README.md"
-requires-python = ">=3.12"
-url = "https://github.com/ShaojieJiang/agentensor"
-version = "0.0.4"
-
 [tool.coverage.report]
 exclude_lines = [
   "pragma: no cover",
@@ -55,6 +53,9 @@ branch = true
 command_line = "-m pytest"
 source = ["agentensor"]

+[tool.hatch.build.targets.wheel]
+packages = ["src/agentensor"]
+
 [tool.mypy]
 disallow_untyped_defs = true
 ignore_missing_imports = true

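The dependency floors jump to langchain>=1.1.3 and langgraph>=1.0.5, which matches the API migration visible in the module diffs below: the removed langgraph.prebuilt.create_react_agent import gives way to langchain.agents.create_agent. A minimal sketch of the new-style call; the model id and prompt here are illustrative placeholders, not taken from this diff:

from langchain.agents import create_agent  # LangChain 1.x replacement for create_react_agent

# Mirrors the calls added in optim.py and tensor.py below.
agent = create_agent(
    "openai:gpt-4o-mini",  # illustrative model id, resolved via init_chat_model
    tools=[],
    system_prompt="Answer the user's question.",
)
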
{agentensor-0.0.4 → agentensor-0.1.0/src}/agentensor/module.py
@@ -5,12 +5,15 @@ from typing import Any
 from langchain.chat_models import init_chat_model
 from langchain_core.language_models import BaseChatModel
 from langchain_core.messages import HumanMessage
-from langgraph.graph.
+from langgraph.graph.state import CompiledStateGraph
 from pydantic import BaseModel, ConfigDict
 from pydantic_ai.exceptions import UnexpectedModelBehavior
 from agentensor.tensor import TextTensor


+CompiledGraph = CompiledStateGraph[Any, Any, Any, Any]
+
+
 class AgentModule(BaseModel, ABC):
     """Agent module."""

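The new CompiledGraph name is just an alias that pins all four generic parameters of langgraph's CompiledStateGraph. A tiny sketch of how such an alias keeps annotations short; run_once is a hypothetical helper for illustration, not part of the package:

from typing import Any

from langgraph.graph.state import CompiledStateGraph

CompiledGraph = CompiledStateGraph[Any, Any, Any, Any]


def run_once(graph: CompiledGraph, state: dict[str, Any]) -> dict[str, Any]:
    """Hypothetical helper: accept any compiled graph without repeating generics."""
    return graph.invoke(state)
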
agentensor-0.1.0/src/agentensor/optim.py (new file)
@@ -0,0 +1,109 @@
+"""Optimizer module."""
+
+from langchain.agents import create_agent
+from langchain.chat_models import init_chat_model
+from langchain_core.language_models import BaseChatModel
+from langchain_core.messages import HumanMessage
+from langchain_core.runnables import Runnable
+from langgraph.graph import StateGraph
+from agentensor.module import AgentModule
+from agentensor.tensor import TextTensor
+
+
+class Optimizer:
+    """Optimizer class."""
+
+    def __init__(
+        self,
+        graph: StateGraph | None = None,
+        model: str | BaseChatModel = "gpt-4o-mini",
+        params: list[TextTensor] | None = None,
+    ) -> None:
+        """Initialize the optimizer."""
+        self.params: list[TextTensor] = (
+            self._coerce_params(params, source="params") if params is not None else []
+        )
+        if graph is not None and not params:
+            for node in graph.nodes.values():
+                runnable = getattr(node, "runnable", None)
+                module: AgentModule | None = None
+
+                if isinstance(runnable, AgentModule):
+                    module = runnable
+                else:
+                    function = getattr(runnable, "afunc", None)
+                    if isinstance(function, AgentModule):
+                        module = function
+                    else:
+                        bound_self = getattr(function, "__self__", None)
+                        if isinstance(bound_self, AgentModule):
+                            module = bound_self
+
+                if module is not None:
+                    self.params.extend(
+                        self._coerce_params(
+                            module.get_params(),
+                            source=f"{module.__class__.__name__}.get_params()",
+                        )
+                    )
+                    continue
+
+                param_provider = getattr(runnable, "get_params", None)
+                if callable(param_provider):
+                    self.params.extend(
+                        self._coerce_params(
+                            param_provider(),
+                            source=f"{runnable.__class__.__name__}.get_params()",
+                        )
+                    )
+        if isinstance(model, str):
+            self.model = init_chat_model(model)
+        else:  # pragma: no cover
+            self.model = model
+
+    def step(self) -> None:
+        """Step the optimizer."""
+        for param in self.params:
+            if not param.text_grad:
+                continue
+            param.text = self.optimize(param.text, param.text_grad)
+
+    def zero_grad(self) -> None:
+        """Zero the gradients."""
+        for param in self.params:
+            param.zero_grad()
+
+    def optimize(self, text: str, grad: str) -> str:
+        """Optimize the text."""
+        result = self.agent.invoke(
+            {"messages": [HumanMessage(content=f"Feedback: {grad}\nText: {text}")]}
+        )
+        return result["messages"][-1].content
+
+    @property
+    def agent(self) -> Runnable:
+        """Get the agent."""
+        return create_agent(
+            self.model,
+            tools=[],
+            system_prompt="Rewrite the system prompt given the feedback.",
+        )
+
+    @staticmethod
+    def _coerce_params(params: object, *, source: str) -> list[TextTensor]:
+        """Normalize parameters and raise if they are incompatible."""
+        if isinstance(params, TextTensor):
+            return [params]
+        if isinstance(params, list | tuple):
+            if not params:
+                return []
+            invalid = [param for param in params if not isinstance(param, TextTensor)]
+            if invalid:
+                raise TypeError(
+                    f"{source} must contain only TextTensor instances, got "
+                    f"{type(invalid[0]).__name__}."
+                )
+            return list(params)
+        raise TypeError(
+            f"{source} must return a TextTensor or list of TextTensor objects."
+        )

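Read together with tensor.py below, the optimizer treats prompt strings as trainable parameters: backward() deposits textual feedback on a TextTensor, step() asks an LLM agent to rewrite each parameter from that feedback, and zero_grad() clears it. A usage sketch, assuming an OpenAI API key is configured; the feedback string is hand-written here rather than produced by an evaluator:

from agentensor.optim import Optimizer
from agentensor.tensor import TextTensor

# A trainable prompt parameter.
system_prompt = TextTensor("You are a helpful agent.", requires_grad=True)

optimizer = Optimizer(params=[system_prompt], model="gpt-4o-mini")

# backward()/text_grad are the TextTensor gradient surface that step() reads.
system_prompt.backward("Answers are too verbose; insist on concise replies.")
optimizer.step()       # rewrites system_prompt.text from the feedback
optimizer.zero_grad()  # clears accumulated gradients for the next round

print(system_prompt.text)
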
agentensor-0.1.0/src/agentensor/py.typed (new file)
File without changes

{agentensor-0.0.4 → agentensor-0.1.0/src}/agentensor/tensor.py
@@ -1,11 +1,12 @@
-"""
+"""Text tensor primitives."""

 from __future__ import annotations
+from typing import Any
+from langchain.agents import create_agent
 from langchain.chat_models import init_chat_model
 from langchain_core.language_models import BaseChatModel
 from langchain_core.messages import HumanMessage
-from
-from langgraph.prebuilt import create_react_agent
+from langchain_core.runnables import Runnable


 class TextTensor:
@@ -16,15 +17,18 @@ class TextTensor:
         text: str,
         parents: list[TextTensor] | None = None,
         requires_grad: bool = False,
-
+        metadata: dict[str, Any] | None = None,
+        model: str | BaseChatModel = "openai:gpt-4o-mini",
+        model_kwargs: dict[str, Any] | None = None,
     ) -> None:
         """Initialize a TextTensor."""
         self.text = text
         self.requires_grad = requires_grad
         self.gradients: list[str] = []
+        self.metadata = dict(metadata) if metadata is not None else {}
         self.parents: list[TextTensor] = parents or []
         if isinstance(model, str):
-            self.model = init_chat_model(model)
+            self.model = init_chat_model(model, **(model_kwargs or {}))
         else:
             self.model = model

@@ -77,8 +81,10 @@
         return self.text

     @property
-    def agent(self) ->
+    def agent(self) -> Runnable:
         """Get the agent."""
-        return
-            self.model,
+        return create_agent(
+            self.model,
+            tools=[],
+            system_prompt="Answer the user's question.",
         )

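The constructor gains metadata (stored as a defensive copy), a configurable model, and model_kwargs forwarded to init_chat_model. A short sketch of the new options; the temperature kwarg is illustrative, anything init_chat_model accepts should pass through:

from agentensor.tensor import TextTensor

tensor = TextTensor(
    "Summarize the document in two sentences.",
    requires_grad=True,
    metadata={"source": "docs-example"},  # copied into tensor.metadata
    model="openai:gpt-4o-mini",           # resolved via init_chat_model
    model_kwargs={"temperature": 0.0},    # forwarded to init_chat_model
)

print(tensor.metadata)
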
agentensor-0.1.0/src/agentensor/train.py (new file)
@@ -0,0 +1,280 @@
+"""Trainer."""
+
+from __future__ import annotations
+import asyncio
+import json
+from collections.abc import Mapping
+from typing import Any, Literal
+from langchain_core.runnables import RunnableConfig
+from langgraph.graph.state import CompiledStateGraph
+from pydantic_evals import Dataset
+from pydantic_evals.reporting import EvaluationReport
+from agentensor.optim import Optimizer
+from agentensor.tensor import TextTensor
+
+
+CompiledGraph = CompiledStateGraph[Any, Any, Any, Any]
+
+
+class Trainer:
+    """Trainer."""
+
+    def __init__(
+        self,
+        graph: CompiledGraph,
+        graph_recursion_limit: int = 25,
+        train_dataset: Dataset[TextTensor, TextTensor, Any] | None = None,
+        eval_dataset: Dataset[TextTensor, TextTensor, Any] | None = None,
+        test_dataset: Dataset[TextTensor, TextTensor, Any] | None = None,
+        optimizer: Optimizer | None = None,
+        epochs: int = 10,
+        stop_threshold: float = 0.95,
+    ):
+        """Initialize the trainer."""
+        self.graph = graph
+        self.graph_recursion_limit = graph_recursion_limit
+        self.optimizer = optimizer
+        self.epochs = epochs
+        self.stop_threshold = stop_threshold
+        self.train_dataset = train_dataset
+        self.eval_dataset = eval_dataset
+        self.test_dataset = test_dataset
+
+    async def forward(self, x: TextTensor) -> TextTensor:
+        """Forward the graph."""
+        result = await self.graph.ainvoke(
+            {"output": x}, {"recursion_limit": self.graph_recursion_limit}
+        )
+        return result["output"]
+
+    def train(self) -> None:
+        """Train the graph."""
+        self._require_dataset("train")
+        optimizer = self._require_optimizer()
+        for i in range(self.epochs):
+            report = self.evaluate("train")
+            report.print(
+                include_input=True, include_output=True, include_durations=True
+            )
+
+            # Backward those failed cases
+            for case in report.cases:
+                losses: list[str] = []
+                for evaluator in case.assertions.values():
+                    if not evaluator.value:
+                        reason = getattr(evaluator, "reason", None)
+                        if reason is None or (
+                            isinstance(reason, str) and not reason.strip()
+                        ):
+                            losses.append("Evaluation failed without a reason.")
+                        else:
+                            losses.append(str(reason))
+                if losses:
+                    case.output.backward(" ".join(losses))
+
+            optimizer.step()
+            optimizer.zero_grad()
+
+            performance = report.averages()
+            assertions = None if performance is None else performance.assertions
+            self.after_epoch(i, report)
+            if assertions is not None and assertions >= self.stop_threshold:
+                print("Optimization complete.")
+                break
+
+    def evaluate(
+        self,
+        data_split: Literal["train", "eval", "test"] = "eval",
+        limit_cases: int | None = None,
+        max_concurrency: int | None = None,
+        progress: bool = True,
+    ) -> EvaluationReport:
+        """Evaluate the graph."""
+        dataset = self._require_dataset(data_split)
+        if limit_cases:
+            limited_cases = dataset.cases[:limit_cases]
+            dataset = Dataset(cases=limited_cases, evaluators=dataset.evaluators)
+        report = dataset.evaluate_sync(
+            self.forward,
+            max_concurrency=max_concurrency,
+            progress=progress,
+        )
+
+        return report
+
+    def test(self, limit_cases: int | None = None) -> None:
+        """Test the graph."""
+        report = self.evaluate("test", limit_cases=limit_cases)
+        report.print(include_input=True, include_output=True, include_durations=True)
+
+    def after_epoch(self, epoch_index: int, report: EvaluationReport) -> None:
+        """Optional hook for subclasses to record state."""
+        return None
+
+    def _require_dataset(self, data_split: str) -> Dataset[Any, Any, Any]:
+        """Return the dataset for a split or raise a descriptive error."""
+        dataset = getattr(self, f"{data_split}_dataset", None)
+        if dataset is None:
+            raise ValueError(f"{data_split} dataset is required")
+        return dataset
+
+    def _require_optimizer(self) -> Optimizer:
+        """Return the optimizer or raise a descriptive error."""
+        if self.optimizer is None:
+            raise ValueError("Optimizer is required")
+        return self.optimizer
+
+
+class GraphTrainer(Trainer):
+    """Trainer that runs a compiled graph against mapping inputs."""
+
+    def __init__(
+        self,
+        *,
+        graph: CompiledGraph,
+        dataset: Dataset[Any, Any, Any],
+        optimizer: Optimizer | None = None,
+        epochs: int,
+        runtime_prompts: Mapping[str, TextTensor],
+        base_state: Mapping[str, Any] | None = None,
+        graph_config: RunnableConfig | None = None,
+        max_concurrency: int = 1,
+        case_timeout: int = 30,
+        stop_threshold: float = 2.0,
+        script_format: bool = False,
+    ) -> None:
+        """Initialize the graph trainer."""
+        super().__init__(
+            graph=graph,
+            train_dataset=dataset,
+            eval_dataset=dataset,
+            optimizer=optimizer,
+            epochs=epochs,
+            stop_threshold=stop_threshold,
+        )
+        self.runtime_prompts = runtime_prompts
+        self.base_state = base_state or {}
+        self.graph_config = graph_config
+        self.max_concurrency = max_concurrency
+        self.case_timeout = case_timeout
+        self.reports: list[EvaluationReport] = []
+        self.script_format = script_format
+        self.prompt_history: list[dict[str, str]] = []
+
+    async def forward(self, case_inputs: Mapping[str, Any]) -> TextTensor:  # type: ignore[override]
+        """Execute the compiled graph and return a tensor with the raw payload."""
+        merged_inputs = self._merge_inputs(case_inputs)
+        case_state = self._build_case_state(merged_inputs)
+        output_state = await asyncio.wait_for(
+            self.graph.ainvoke(case_state, config=self.graph_config),
+            timeout=self.case_timeout,
+        )
+        output_payload = self._extract_output(output_state)
+        parents = [
+            prompt for prompt in self.runtime_prompts.values() if prompt.requires_grad
+        ]
+        tensor = TextTensor(
+            text=self._stringify_output(output_payload),
+            requires_grad=True,
+            parents=parents,
+            metadata={"payload": output_payload},
+        )
+        return tensor
+
+    def evaluate(
+        self,
+        data_split: Literal["train", "eval", "test"] = "eval",
+        limit_cases: int | None = None,
+        max_concurrency: int | None = None,
+        progress: bool = False,
+    ) -> EvaluationReport:
+        """Run evaluation without rendering progress to stdout."""
+        dataset = self._require_dataset(data_split)
+        cases = dataset.cases
+        if limit_cases is not None:
+            cases = cases[:limit_cases]
+        dataset = Dataset(cases=cases, evaluators=dataset.evaluators)
+        report = dataset.evaluate_sync(
+            self.forward,
+            max_concurrency=max_concurrency or self.max_concurrency,
+            progress=progress,
+        )
+        self.reports.append(report)
+        return report
+
+    def _merge_inputs(self, inputs: Mapping[str, Any]) -> dict[str, Any]:
+        merged: dict[str, Any] = {}
+        if isinstance(self.base_state, Mapping):
+            base_inputs = self.base_state.get("inputs", self.base_state)
+            if isinstance(base_inputs, Mapping):  # pragma: no branch
+                merged.update(base_inputs)
+        merged.update(inputs)
+        return merged
+
+    def _build_case_state(self, inputs: Mapping[str, Any]) -> dict[str, Any]:
+        runtime_config = (
+            dict(self.base_state.get("config", {}))
+            if isinstance(self.base_state, Mapping)
+            else {}
+        )
+        if self.script_format:
+            state = dict(inputs)
+            state["config"] = runtime_config | {"prompts": self.runtime_prompts}
+            return state
+        return {
+            "messages": [],
+            "results": {},
+            "inputs": dict(inputs),
+            "structured_response": None,
+            "config": runtime_config | {"prompts": self.runtime_prompts},
+        }
+
+    @staticmethod
+    def _stringify_output(output: Any) -> str:
+        if isinstance(output, str):
+            return output
+        try:
+            return json.dumps(output)
+        except TypeError:
+            return str(output)
+
+    @staticmethod
+    def _extract_output(output_state: Any) -> Any:
+        if isinstance(output_state, Mapping):  # pragma: no branch
+            results = output_state.get("results")
+            if isinstance(results, Mapping) and results:
+                return results
+            if "output" in output_state:
+                return output_state["output"]
+            message_output = GraphTrainer._extract_message_output(
+                output_state.get("messages")
+            )
+            if message_output is not None:
+                return message_output
+        return output_state
+
+    @staticmethod
+    def _extract_message_output(messages: Any) -> Any | None:
+        if not isinstance(messages, list):
+            return None
+        fallback: Any | None = None
+        for message in reversed(messages):
+            if isinstance(message, Mapping):
+                content = message.get("content")
+                role = message.get("role") or message.get("type")
+            else:
+                content = getattr(message, "content", None)
+                role = getattr(message, "role", None) or getattr(message, "type", None)
+            if role in {"assistant", "ai"}:
+                return content
+            if fallback is None and content is not None:  # pragma: no branch
+                if not (isinstance(content, str) and not content.strip()):
+                    fallback = content
+        return fallback
+
+    def after_epoch(self, epoch_index: int, report: EvaluationReport) -> None:
+        """Record prompt snapshots after each optimizer step."""
+        snapshot: dict[str, str] = {
+            name: tensor.text for name, tensor in self.runtime_prompts.items()
+        }
+        self.prompt_history.append(snapshot)

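GraphTrainer wires a compiled LangGraph graph, a pydantic_evals dataset, and the optimizer into the epoch loop above, injecting runtime_prompts into each case's state under config["prompts"] and snapshotting them after every epoch. A wiring sketch; build_graph() and build_dataset() are placeholders for project-specific setup this diff does not show:

from agentensor.optim import Optimizer
from agentensor.tensor import TextTensor
from agentensor.train import GraphTrainer

prompts = {"system": TextTensor("You are a router.", requires_grad=True)}

trainer = GraphTrainer(
    graph=build_graph(),      # placeholder: a CompiledStateGraph
    dataset=build_dataset(),  # placeholder: pydantic_evals Dataset with evaluators
    optimizer=Optimizer(params=list(prompts.values())),
    epochs=3,
    runtime_prompts=prompts,  # merged into graph state as config["prompts"]
)

trainer.train()
print(trainer.prompt_history[-1])  # prompt snapshot recorded after each epoch
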
agentensor-0.0.4/.bumpversion.cfg (deleted)
@@ -1,12 +0,0 @@
-[bumpversion]
-current_version = 0.0.4
-commit = True
-tag = True
-
-[bumpversion:file:pyproject.toml]
-search = version = "{current_version}"
-replace = version = "{new_version}"
-
-[bumpversion:file:uv.lock]
-search = version = "{current_version}"
-replace = version = "{new_version}"