agentensor 0.0.4__tar.gz → 0.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (32)
  1. agentensor-0.1.0/.bumpversion.cfg +18 -0
  2. {agentensor-0.0.4 → agentensor-0.1.0}/.gitignore +40 -10
  3. agentensor-0.1.0/PKG-INFO +40 -0
  4. {agentensor-0.0.4 → agentensor-0.1.0}/pyproject.toml +21 -20
  5. agentensor-0.1.0/src/agentensor/__init__.py +8 -0
  6. {agentensor-0.0.4 → agentensor-0.1.0/src}/agentensor/module.py +4 -1
  7. agentensor-0.1.0/src/agentensor/optim.py +109 -0
  8. agentensor-0.1.0/src/agentensor/py.typed +0 -0
  9. {agentensor-0.0.4 → agentensor-0.1.0/src}/agentensor/tensor.py +14 -8
  10. agentensor-0.1.0/src/agentensor/train.py +280 -0
  11. agentensor-0.0.4/.bumpversion.cfg +0 -12
  12. agentensor-0.0.4/.github/workflows/after-ci.yml +0 -50
  13. agentensor-0.0.4/.github/workflows/ci.yml +0 -103
  14. agentensor-0.0.4/.pre-commit-config.yaml +0 -26
  15. agentensor-0.0.4/.python-version +0 -1
  16. agentensor-0.0.4/LICENSE +0 -21
  17. agentensor-0.0.4/Makefile +0 -15
  18. agentensor-0.0.4/PKG-INFO +0 -45
  19. agentensor-0.0.4/agentensor/__init__.py +0 -1
  20. agentensor-0.0.4/agentensor/optim.py +0 -57
  21. agentensor-0.0.4/agentensor/train.py +0 -93
  22. agentensor-0.0.4/examples/evaluate.py +0 -154
  23. agentensor-0.0.4/examples/train.py +0 -118
  24. agentensor-0.0.4/mkdocs.yml +0 -1
  25. agentensor-0.0.4/tests/test_loss.py +0 -96
  26. agentensor-0.0.4/tests/test_module.py +0 -119
  27. agentensor-0.0.4/tests/test_optim.py +0 -130
  28. agentensor-0.0.4/tests/test_tensor.py +0 -156
  29. agentensor-0.0.4/tests/test_train.py +0 -266
  30. agentensor-0.0.4/uv.lock +0 -3431
  31. {agentensor-0.0.4 → agentensor-0.1.0}/README.md +0 -0
  32. {agentensor-0.0.4 → agentensor-0.1.0/src}/agentensor/loss.py +0 -0
@@ -0,0 +1,18 @@
1
+ [bumpversion]
2
+ current_version = 0.1.0
3
+ commit = True
4
+ tag = True
5
+ tag_name = agentensor-v{new_version}
6
+ message = Bump agentensor version to {new_version}
7
+
8
+ [bumpversion:file:pyproject.toml]
9
+ search = version = "{current_version}"
10
+ replace = version = "{new_version}"
11
+
12
+ [bumpversion:file:../../uv.lock]
13
+ search = [[package]]
14
+ name = "agentensor"
15
+ version = "{current_version}"
16
+ replace = [[package]]
17
+ name = "agentensor"
18
+ version = "{new_version}"
@@ -10,11 +10,10 @@ __pycache__/
10
10
  .Python
11
11
  build/
12
12
  develop-eggs/
13
- dist/
14
13
  downloads/
15
14
  eggs/
16
15
  .eggs/
17
- lib/
16
+ /lib/
18
17
  lib64/
19
18
  parts/
20
19
  sdist/
@@ -57,7 +56,6 @@ cover/
57
56
  *.pot
58
57
 
59
58
  # Django stuff:
60
- *.log
61
59
  local_settings.py
62
60
  db.sqlite3
63
61
  db.sqlite3-journal
@@ -153,12 +151,44 @@ dmypy.json
153
151
  # Cython debug symbols
154
152
  cython_debug/
155
153
 
156
- # PyCharm
157
- # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
158
- # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
159
- # and can be added to the global gitignore or merged into this file. For a more nuclear
160
- # option (not recommended) you can uncomment the following to ignore the entire idea folder.
161
- #.idea/
154
+ # Dependencies
155
+ node_modules/
156
+
157
+ # Logs
158
+ logs
159
+ *.log
160
+ npm-debug.log*
161
+ yarn-debug.log*
162
+ yarn-error.log*
163
+ pnpm-debug.log*
164
+ lerna-debug.log*
162
165
 
163
- # MacOS
166
+ # Build
167
+ dist/
168
+ dist-ssr/
169
+ *.local
170
+
171
+ # Editor directories and files
172
+ .vscode/*
173
+ !.vscode/extensions.json
174
+ .idea/
175
+ *.suo
176
+ *.ntvs*
177
+ *.njsproj
178
+ *.sln
179
+ *.sw?
180
+
181
+ # OS generated files
182
+ .DS_Store
164
183
  **/.DS_Store
184
+
185
+ # LangGraph
186
+ checkpoints.sqlite
187
+
188
+ # Orcheo
189
+ .orcheo/
190
+ **/workflow_config.json
191
+
192
+ # Miscellaneous
193
+ homepage/
194
+ orcheo-paper/
@@ -0,0 +1,40 @@
1
+ Metadata-Version: 2.4
2
+ Name: agentensor
3
+ Version: 0.1.0
4
+ Summary: Agent prompt tensors, modules, and optimizers for Orcheo workflows
5
+ Requires-Python: >=3.12
6
+ Requires-Dist: datasets>=3.5.0
7
+ Requires-Dist: langchain-ollama>=0.3.3
8
+ Requires-Dist: langchain-openai>=0.3.18
9
+ Requires-Dist: langchain>=1.1.3
10
+ Requires-Dist: langgraph>=1.0.5
11
+ Requires-Dist: logfire>=3.14.0
12
+ Requires-Dist: pydantic-ai>=0.2.4
13
+ Provides-Extra: dev
14
+ Requires-Dist: bump2version; extra == 'dev'
15
+ Requires-Dist: diff-cover>=9.2.4; extra == 'dev'
16
+ Requires-Dist: mypy>=1.11.2; extra == 'dev'
17
+ Requires-Dist: pre-commit; extra == 'dev'
18
+ Requires-Dist: pytest; extra == 'dev'
19
+ Requires-Dist: pytest-asyncio>=0.23.8; extra == 'dev'
20
+ Requires-Dist: pytest-cov>=4.1.0; extra == 'dev'
21
+ Requires-Dist: ruff>=0.11.3; extra == 'dev'
22
+ Requires-Dist: smokeshow>=0.5.0; extra == 'dev'
23
+ Requires-Dist: types-requests; extra == 'dev'
24
+ Provides-Extra: docs
25
+ Requires-Dist: mkdocs; extra == 'docs'
26
+ Requires-Dist: mkdocs-gen-files; extra == 'docs'
27
+ Requires-Dist: mkdocs-jupyter; extra == 'docs'
28
+ Requires-Dist: mkdocs-material; extra == 'docs'
29
+ Requires-Dist: mkdocstrings[python]>=0.28.1; extra == 'docs'
30
+ Description-Content-Type: text/markdown
31
+
32
+ # AgenTensor
33
+
34
+ [![CI](https://github.com/ShaojieJiang/agentensor/actions/workflows/ci.yml/badge.svg?event=push)](https://github.com/ShaojieJiang/agentensor/actions/workflows/ci.yml?query=branch%3Amain)
35
+ [![Coverage](https://coverage-badge.samuelcolvin.workers.dev/ShaojieJiang/agentensor.svg)](https://coverage-badge.samuelcolvin.workers.dev/redirect/ShaojieJiang/agentensor)
36
+ [![PyPI](https://img.shields.io/pypi/v/agentensor.svg)](https://pypi.python.org/pypi/agentensor)
37
+
38
+ ## TODO
39
+
40
+ - [ ] Add parameter saving
@@ -2,11 +2,27 @@
2
2
  build-backend = "hatchling.build"
3
3
  requires = ["hatchling"]
4
4
 
5
- [dependency-groups]
5
+ [project]
6
+ dependencies = [
7
+ "datasets>=3.5.0",
8
+ "langchain>=1.1.3",
9
+ "langchain-ollama>=0.3.3",
10
+ "langchain-openai>=0.3.18",
11
+ "langgraph>=1.0.5",
12
+ "logfire>=3.14.0",
13
+ "pydantic-ai>=0.2.4"
14
+ ]
15
+ description = "Agent prompt tensors, modules, and optimizers for Orcheo workflows"
16
+ name = "agentensor"
17
+ readme = "README.md"
18
+ requires-python = ">=3.12"
19
+ url = "https://github.com/ShaojieJiang/agentensor"
20
+ version = "0.1.0"
21
+
22
+ [project.optional-dependencies]
6
23
  dev = [
7
24
  "bump2version",
8
25
  "diff-cover>=9.2.4",
9
- "isort",
10
26
  "mypy>=1.11.2",
11
27
  "pre-commit",
12
28
  "pytest",
@@ -24,24 +40,6 @@ docs = [
24
40
  "mkdocstrings[python]>=0.28.1"
25
41
  ]
26
42
 
27
- [project]
28
- dependencies = [
29
- "datasets>=3.5.0",
30
- "langchain>=0.3.25",
31
- "langchain-ollama>=0.3.3",
32
- "langchain-openai>=0.3.18",
33
- "langgraph>=0.4.5",
34
- "logfire>=3.14.0",
35
- "pydantic-ai>=0.2.4"
36
- ]
37
- description = "Add your description here"
38
- license = {file = "LICENSE"}
39
- name = "agentensor"
40
- readme = "README.md"
41
- requires-python = ">=3.12"
42
- url = "https://github.com/ShaojieJiang/agentensor"
43
- version = "0.0.4"
44
-
45
43
  [tool.coverage.report]
46
44
  exclude_lines = [
47
45
  "pragma: no cover",
@@ -55,6 +53,9 @@ branch = true
55
53
  command_line = "-m pytest"
56
54
  source = ["agentensor"]
57
55
 
56
+ [tool.hatch.build.targets.wheel]
57
+ packages = ["src/agentensor"]
58
+
58
59
  [tool.mypy]
59
60
  disallow_untyped_defs = true
60
61
  ignore_missing_imports = true
@@ -0,0 +1,8 @@
1
+ """Agentensor core primitives."""
2
+
3
+ from agentensor.loss import LLMTensorJudge
4
+ from agentensor.optim import Optimizer
5
+ from agentensor.train import GraphTrainer, Trainer
6
+
7
+
8
+ __all__ = ["GraphTrainer", "LLMTensorJudge", "Optimizer", "Trainer"]
@@ -5,12 +5,15 @@ from typing import Any
5
5
  from langchain.chat_models import init_chat_model
6
6
  from langchain_core.language_models import BaseChatModel
7
7
  from langchain_core.messages import HumanMessage
8
- from langgraph.graph.graph import CompiledGraph
8
+ from langgraph.graph.state import CompiledStateGraph
9
9
  from pydantic import BaseModel, ConfigDict
10
10
  from pydantic_ai.exceptions import UnexpectedModelBehavior
11
11
  from agentensor.tensor import TextTensor
12
12
 
13
13
 
14
+ CompiledGraph = CompiledStateGraph[Any, Any, Any, Any]
15
+
16
+
14
17
  class AgentModule(BaseModel, ABC):
15
18
  """Agent module."""
16
19
 
@@ -0,0 +1,109 @@
1
+ """Optimizer module."""
2
+
3
+ from langchain.agents import create_agent
4
+ from langchain.chat_models import init_chat_model
5
+ from langchain_core.language_models import BaseChatModel
6
+ from langchain_core.messages import HumanMessage
7
+ from langchain_core.runnables import Runnable
8
+ from langgraph.graph import StateGraph
9
+ from agentensor.module import AgentModule
10
+ from agentensor.tensor import TextTensor
11
+
12
+
13
+ class Optimizer:
14
+ """Optimizer class."""
15
+
16
+ def __init__(
17
+ self,
18
+ graph: StateGraph | None = None,
19
+ model: str | BaseChatModel = "gpt-4o-mini",
20
+ params: list[TextTensor] | None = None,
21
+ ) -> None:
22
+ """Initialize the optimizer."""
23
+ self.params: list[TextTensor] = (
24
+ self._coerce_params(params, source="params") if params is not None else []
25
+ )
26
+ if graph is not None and not params:
27
+ for node in graph.nodes.values():
28
+ runnable = getattr(node, "runnable", None)
29
+ module: AgentModule | None = None
30
+
31
+ if isinstance(runnable, AgentModule):
32
+ module = runnable
33
+ else:
34
+ function = getattr(runnable, "afunc", None)
35
+ if isinstance(function, AgentModule):
36
+ module = function
37
+ else:
38
+ bound_self = getattr(function, "__self__", None)
39
+ if isinstance(bound_self, AgentModule):
40
+ module = bound_self
41
+
42
+ if module is not None:
43
+ self.params.extend(
44
+ self._coerce_params(
45
+ module.get_params(),
46
+ source=f"{module.__class__.__name__}.get_params()",
47
+ )
48
+ )
49
+ continue
50
+
51
+ param_provider = getattr(runnable, "get_params", None)
52
+ if callable(param_provider):
53
+ self.params.extend(
54
+ self._coerce_params(
55
+ param_provider(),
56
+ source=f"{runnable.__class__.__name__}.get_params()",
57
+ )
58
+ )
59
+ if isinstance(model, str):
60
+ self.model = init_chat_model(model)
61
+ else: # pragma: no cover
62
+ self.model = model
63
+
64
+ def step(self) -> None:
65
+ """Step the optimizer."""
66
+ for param in self.params:
67
+ if not param.text_grad:
68
+ continue
69
+ param.text = self.optimize(param.text, param.text_grad)
70
+
71
+ def zero_grad(self) -> None:
72
+ """Zero the gradients."""
73
+ for param in self.params:
74
+ param.zero_grad()
75
+
76
+ def optimize(self, text: str, grad: str) -> str:
77
+ """Optimize the text."""
78
+ result = self.agent.invoke(
79
+ {"messages": [HumanMessage(content=f"Feedback: {grad}\nText: {text}")]}
80
+ )
81
+ return result["messages"][-1].content
82
+
83
+ @property
84
+ def agent(self) -> Runnable:
85
+ """Get the agent."""
86
+ return create_agent(
87
+ self.model,
88
+ tools=[],
89
+ system_prompt="Rewrite the system prompt given the feedback.",
90
+ )
91
+
92
+ @staticmethod
93
+ def _coerce_params(params: object, *, source: str) -> list[TextTensor]:
94
+ """Normalize parameters and raise if they are incompatible."""
95
+ if isinstance(params, TextTensor):
96
+ return [params]
97
+ if isinstance(params, list | tuple):
98
+ if not params:
99
+ return []
100
+ invalid = [param for param in params if not isinstance(param, TextTensor)]
101
+ if invalid:
102
+ raise TypeError(
103
+ f"{source} must contain only TextTensor instances, got "
104
+ f"{type(invalid[0]).__name__}."
105
+ )
106
+ return list(params)
107
+ raise TypeError(
108
+ f"{source} must return a TextTensor or list of TextTensor objects."
109
+ )
File without changes
@@ -1,11 +1,12 @@
1
- """Example module."""
1
+ """Text tensor primitives."""
2
2
 
3
3
  from __future__ import annotations
4
+ from typing import Any
5
+ from langchain.agents import create_agent
4
6
  from langchain.chat_models import init_chat_model
5
7
  from langchain_core.language_models import BaseChatModel
6
8
  from langchain_core.messages import HumanMessage
7
- from langgraph.graph.graph import CompiledGraph
8
- from langgraph.prebuilt import create_react_agent
9
+ from langchain_core.runnables import Runnable
9
10
 
10
11
 
11
12
  class TextTensor:
@@ -16,15 +17,18 @@ class TextTensor:
16
17
  text: str,
17
18
  parents: list[TextTensor] | None = None,
18
19
  requires_grad: bool = False,
19
- model: str | BaseChatModel = "gpt-4o-mini",
20
+ metadata: dict[str, Any] | None = None,
21
+ model: str | BaseChatModel = "openai:gpt-4o-mini",
22
+ model_kwargs: dict[str, Any] | None = None,
20
23
  ) -> None:
21
24
  """Initialize a TextTensor."""
22
25
  self.text = text
23
26
  self.requires_grad = requires_grad
24
27
  self.gradients: list[str] = []
28
+ self.metadata = dict(metadata) if metadata is not None else {}
25
29
  self.parents: list[TextTensor] = parents or []
26
30
  if isinstance(model, str):
27
- self.model = init_chat_model(model)
31
+ self.model = init_chat_model(model, **(model_kwargs or {}))
28
32
  else:
29
33
  self.model = model
30
34
 
@@ -77,8 +81,10 @@ class TextTensor:
77
81
  return self.text
78
82
 
79
83
  @property
80
- def agent(self) -> CompiledGraph:
84
+ def agent(self) -> Runnable:
81
85
  """Get the agent."""
82
- return create_react_agent(
83
- self.model, tools=[], prompt="Answer the user's question."
86
+ return create_agent(
87
+ self.model,
88
+ tools=[],
89
+ system_prompt="Answer the user's question.",
84
90
  )
@@ -0,0 +1,280 @@
1
+ """Trainer."""
2
+
3
+ from __future__ import annotations
4
+ import asyncio
5
+ import json
6
+ from collections.abc import Mapping
7
+ from typing import Any, Literal
8
+ from langchain_core.runnables import RunnableConfig
9
+ from langgraph.graph.state import CompiledStateGraph
10
+ from pydantic_evals import Dataset
11
+ from pydantic_evals.reporting import EvaluationReport
12
+ from agentensor.optim import Optimizer
13
+ from agentensor.tensor import TextTensor
14
+
15
+
16
+ CompiledGraph = CompiledStateGraph[Any, Any, Any, Any]
17
+
18
+
19
+ class Trainer:
20
+ """Trainer."""
21
+
22
+ def __init__(
23
+ self,
24
+ graph: CompiledGraph,
25
+ graph_recursion_limit: int = 25,
26
+ train_dataset: Dataset[TextTensor, TextTensor, Any] | None = None,
27
+ eval_dataset: Dataset[TextTensor, TextTensor, Any] | None = None,
28
+ test_dataset: Dataset[TextTensor, TextTensor, Any] | None = None,
29
+ optimizer: Optimizer | None = None,
30
+ epochs: int = 10,
31
+ stop_threshold: float = 0.95,
32
+ ):
33
+ """Initialize the trainer."""
34
+ self.graph = graph
35
+ self.graph_recursion_limit = graph_recursion_limit
36
+ self.optimizer = optimizer
37
+ self.epochs = epochs
38
+ self.stop_threshold = stop_threshold
39
+ self.train_dataset = train_dataset
40
+ self.eval_dataset = eval_dataset
41
+ self.test_dataset = test_dataset
42
+
43
+ async def forward(self, x: TextTensor) -> TextTensor:
44
+ """Forward the graph."""
45
+ result = await self.graph.ainvoke(
46
+ {"output": x}, {"recursion_limit": self.graph_recursion_limit}
47
+ )
48
+ return result["output"]
49
+
50
+ def train(self) -> None:
51
+ """Train the graph."""
52
+ self._require_dataset("train")
53
+ optimizer = self._require_optimizer()
54
+ for i in range(self.epochs):
55
+ report = self.evaluate("train")
56
+ report.print(
57
+ include_input=True, include_output=True, include_durations=True
58
+ )
59
+
60
+ # Backward those failed cases
61
+ for case in report.cases:
62
+ losses: list[str] = []
63
+ for evaluator in case.assertions.values():
64
+ if not evaluator.value:
65
+ reason = getattr(evaluator, "reason", None)
66
+ if reason is None or (
67
+ isinstance(reason, str) and not reason.strip()
68
+ ):
69
+ losses.append("Evaluation failed without a reason.")
70
+ else:
71
+ losses.append(str(reason))
72
+ if losses:
73
+ case.output.backward(" ".join(losses))
74
+
75
+ optimizer.step()
76
+ optimizer.zero_grad()
77
+
78
+ performance = report.averages()
79
+ assertions = None if performance is None else performance.assertions
80
+ self.after_epoch(i, report)
81
+ if assertions is not None and assertions >= self.stop_threshold:
82
+ print("Optimization complete.")
83
+ break
84
+
85
+ def evaluate(
86
+ self,
87
+ data_split: Literal["train", "eval", "test"] = "eval",
88
+ limit_cases: int | None = None,
89
+ max_concurrency: int | None = None,
90
+ progress: bool = True,
91
+ ) -> EvaluationReport:
92
+ """Evaluate the graph."""
93
+ dataset = self._require_dataset(data_split)
94
+ if limit_cases:
95
+ limited_cases = dataset.cases[:limit_cases]
96
+ dataset = Dataset(cases=limited_cases, evaluators=dataset.evaluators)
97
+ report = dataset.evaluate_sync(
98
+ self.forward,
99
+ max_concurrency=max_concurrency,
100
+ progress=progress,
101
+ )
102
+
103
+ return report
104
+
105
+ def test(self, limit_cases: int | None = None) -> None:
106
+ """Test the graph."""
107
+ report = self.evaluate("test", limit_cases=limit_cases)
108
+ report.print(include_input=True, include_output=True, include_durations=True)
109
+
110
+ def after_epoch(self, epoch_index: int, report: EvaluationReport) -> None:
111
+ """Optional hook for subclasses to record state."""
112
+ return None
113
+
114
+ def _require_dataset(self, data_split: str) -> Dataset[Any, Any, Any]:
115
+ """Return the dataset for a split or raise a descriptive error."""
116
+ dataset = getattr(self, f"{data_split}_dataset", None)
117
+ if dataset is None:
118
+ raise ValueError(f"{data_split} dataset is required")
119
+ return dataset
120
+
121
+ def _require_optimizer(self) -> Optimizer:
122
+ """Return the optimizer or raise a descriptive error."""
123
+ if self.optimizer is None:
124
+ raise ValueError("Optimizer is required")
125
+ return self.optimizer
126
+
127
+
128
+ class GraphTrainer(Trainer):
129
+ """Trainer that runs a compiled graph against mapping inputs."""
130
+
131
+ def __init__(
132
+ self,
133
+ *,
134
+ graph: CompiledGraph,
135
+ dataset: Dataset[Any, Any, Any],
136
+ optimizer: Optimizer | None = None,
137
+ epochs: int,
138
+ runtime_prompts: Mapping[str, TextTensor],
139
+ base_state: Mapping[str, Any] | None = None,
140
+ graph_config: RunnableConfig | None = None,
141
+ max_concurrency: int = 1,
142
+ case_timeout: int = 30,
143
+ stop_threshold: float = 2.0,
144
+ script_format: bool = False,
145
+ ) -> None:
146
+ """Initialize the graph trainer."""
147
+ super().__init__(
148
+ graph=graph,
149
+ train_dataset=dataset,
150
+ eval_dataset=dataset,
151
+ optimizer=optimizer,
152
+ epochs=epochs,
153
+ stop_threshold=stop_threshold,
154
+ )
155
+ self.runtime_prompts = runtime_prompts
156
+ self.base_state = base_state or {}
157
+ self.graph_config = graph_config
158
+ self.max_concurrency = max_concurrency
159
+ self.case_timeout = case_timeout
160
+ self.reports: list[EvaluationReport] = []
161
+ self.script_format = script_format
162
+ self.prompt_history: list[dict[str, str]] = []
163
+
164
+ async def forward(self, case_inputs: Mapping[str, Any]) -> TextTensor: # type: ignore[override]
165
+ """Execute the compiled graph and return a tensor with the raw payload."""
166
+ merged_inputs = self._merge_inputs(case_inputs)
167
+ case_state = self._build_case_state(merged_inputs)
168
+ output_state = await asyncio.wait_for(
169
+ self.graph.ainvoke(case_state, config=self.graph_config),
170
+ timeout=self.case_timeout,
171
+ )
172
+ output_payload = self._extract_output(output_state)
173
+ parents = [
174
+ prompt for prompt in self.runtime_prompts.values() if prompt.requires_grad
175
+ ]
176
+ tensor = TextTensor(
177
+ text=self._stringify_output(output_payload),
178
+ requires_grad=True,
179
+ parents=parents,
180
+ metadata={"payload": output_payload},
181
+ )
182
+ return tensor
183
+
184
+ def evaluate(
185
+ self,
186
+ data_split: Literal["train", "eval", "test"] = "eval",
187
+ limit_cases: int | None = None,
188
+ max_concurrency: int | None = None,
189
+ progress: bool = False,
190
+ ) -> EvaluationReport:
191
+ """Run evaluation without rendering progress to stdout."""
192
+ dataset = self._require_dataset(data_split)
193
+ cases = dataset.cases
194
+ if limit_cases is not None:
195
+ cases = cases[:limit_cases]
196
+ dataset = Dataset(cases=cases, evaluators=dataset.evaluators)
197
+ report = dataset.evaluate_sync(
198
+ self.forward,
199
+ max_concurrency=max_concurrency or self.max_concurrency,
200
+ progress=progress,
201
+ )
202
+ self.reports.append(report)
203
+ return report
204
+
205
+ def _merge_inputs(self, inputs: Mapping[str, Any]) -> dict[str, Any]:
206
+ merged: dict[str, Any] = {}
207
+ if isinstance(self.base_state, Mapping):
208
+ base_inputs = self.base_state.get("inputs", self.base_state)
209
+ if isinstance(base_inputs, Mapping): # pragma: no branch
210
+ merged.update(base_inputs)
211
+ merged.update(inputs)
212
+ return merged
213
+
214
+ def _build_case_state(self, inputs: Mapping[str, Any]) -> dict[str, Any]:
215
+ runtime_config = (
216
+ dict(self.base_state.get("config", {}))
217
+ if isinstance(self.base_state, Mapping)
218
+ else {}
219
+ )
220
+ if self.script_format:
221
+ state = dict(inputs)
222
+ state["config"] = runtime_config | {"prompts": self.runtime_prompts}
223
+ return state
224
+ return {
225
+ "messages": [],
226
+ "results": {},
227
+ "inputs": dict(inputs),
228
+ "structured_response": None,
229
+ "config": runtime_config | {"prompts": self.runtime_prompts},
230
+ }
231
+
232
+ @staticmethod
233
+ def _stringify_output(output: Any) -> str:
234
+ if isinstance(output, str):
235
+ return output
236
+ try:
237
+ return json.dumps(output)
238
+ except TypeError:
239
+ return str(output)
240
+
241
+ @staticmethod
242
+ def _extract_output(output_state: Any) -> Any:
243
+ if isinstance(output_state, Mapping): # pragma: no branch
244
+ results = output_state.get("results")
245
+ if isinstance(results, Mapping) and results:
246
+ return results
247
+ if "output" in output_state:
248
+ return output_state["output"]
249
+ message_output = GraphTrainer._extract_message_output(
250
+ output_state.get("messages")
251
+ )
252
+ if message_output is not None:
253
+ return message_output
254
+ return output_state
255
+
256
+ @staticmethod
257
+ def _extract_message_output(messages: Any) -> Any | None:
258
+ if not isinstance(messages, list):
259
+ return None
260
+ fallback: Any | None = None
261
+ for message in reversed(messages):
262
+ if isinstance(message, Mapping):
263
+ content = message.get("content")
264
+ role = message.get("role") or message.get("type")
265
+ else:
266
+ content = getattr(message, "content", None)
267
+ role = getattr(message, "role", None) or getattr(message, "type", None)
268
+ if role in {"assistant", "ai"}:
269
+ return content
270
+ if fallback is None and content is not None: # pragma: no branch
271
+ if not (isinstance(content, str) and not content.strip()):
272
+ fallback = content
273
+ return fallback
274
+
275
+ def after_epoch(self, epoch_index: int, report: EvaluationReport) -> None:
276
+ """Record prompt snapshots after each optimizer step."""
277
+ snapshot: dict[str, str] = {
278
+ name: tensor.text for name, tensor in self.runtime_prompts.items()
279
+ }
280
+ self.prompt_history.append(snapshot)
@@ -1,12 +0,0 @@
1
- [bumpversion]
2
- current_version = 0.0.4
3
- commit = True
4
- tag = True
5
-
6
- [bumpversion:file:pyproject.toml]
7
- search = version = "{current_version}"
8
- replace = version = "{new_version}"
9
-
10
- [bumpversion:file:uv.lock]
11
- search = version = "{current_version}"
12
- replace = version = "{new_version}"