agentensor 0.0.1.tar.gz → 0.0.3.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (29)
  1. {agentensor-0.0.1 → agentensor-0.0.3}/.bumpversion.cfg +1 -1
  2. {agentensor-0.0.1 → agentensor-0.0.3}/PKG-INFO +4 -2
  3. agentensor-0.0.3/agentensor/module.py +49 -0
  4. {agentensor-0.0.1 → agentensor-0.0.3}/agentensor/optim.py +12 -8
  5. {agentensor-0.0.1 → agentensor-0.0.3}/agentensor/tensor.py +16 -5
  6. agentensor-0.0.3/agentensor/train.py +93 -0
  7. agentensor-0.0.3/examples/evaluate.py +151 -0
  8. agentensor-0.0.3/examples/train.py +111 -0
  9. {agentensor-0.0.1 → agentensor-0.0.3}/pyproject.toml +6 -5
  10. {agentensor-0.0.1 → agentensor-0.0.3}/tests/test_module.py +37 -37
  11. {agentensor-0.0.1 → agentensor-0.0.3}/tests/test_optim.py +11 -31
  12. {agentensor-0.0.1 → agentensor-0.0.3}/tests/test_tensor.py +2 -2
  13. {agentensor-0.0.1 → agentensor-0.0.3}/tests/test_train.py +41 -25
  14. {agentensor-0.0.1 → agentensor-0.0.3}/uv.lock +827 -67
  15. agentensor-0.0.1/agentensor/module.py +0 -26
  16. agentensor-0.0.1/agentensor/train.py +0 -66
  17. agentensor-0.0.1/examples/example.py +0 -94
  18. {agentensor-0.0.1 → agentensor-0.0.3}/.github/workflows/after-ci.yml +0 -0
  19. {agentensor-0.0.1 → agentensor-0.0.3}/.github/workflows/ci.yml +0 -0
  20. {agentensor-0.0.1 → agentensor-0.0.3}/.gitignore +0 -0
  21. {agentensor-0.0.1 → agentensor-0.0.3}/.pre-commit-config.yaml +0 -0
  22. {agentensor-0.0.1 → agentensor-0.0.3}/.python-version +0 -0
  23. {agentensor-0.0.1 → agentensor-0.0.3}/LICENSE +0 -0
  24. {agentensor-0.0.1 → agentensor-0.0.3}/Makefile +0 -0
  25. {agentensor-0.0.1 → agentensor-0.0.3}/README.md +0 -0
  26. {agentensor-0.0.1 → agentensor-0.0.3}/agentensor/__init__.py +0 -0
  27. {agentensor-0.0.1 → agentensor-0.0.3}/agentensor/loss.py +0 -0
  28. {agentensor-0.0.1 → agentensor-0.0.3}/mkdocs.yml +0 -0
  29. {agentensor-0.0.1 → agentensor-0.0.3}/tests/test_loss.py +0 -0
@@ -1,5 +1,5 @@
1
1
  [bumpversion]
2
- current_version = 0.0.1
2
+ current_version = 0.0.3
3
3
  commit = True
4
4
  tag = True
5
5
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: agentensor
3
- Version: 0.0.1
3
+ Version: 0.0.3
4
4
  Summary: Add your description here
5
5
  License: MIT License
6
6
 
@@ -25,8 +25,10 @@ License: MIT License
25
25
  SOFTWARE.
26
26
  License-File: LICENSE
27
27
  Requires-Python: >=3.12
28
+ Requires-Dist: datasets>=3.5.0
29
+ Requires-Dist: langgraph>=0.4.5
28
30
  Requires-Dist: logfire>=3.14.0
29
- Requires-Dist: pydantic-ai>=0.0.55
31
+ Requires-Dist: pydantic-ai>=0.2.4
30
32
  Description-Content-Type: text/markdown
31
33
 
32
34
  # AgenTensor
@@ -0,0 +1,49 @@
1
+ """Module class."""
2
+
3
+ from abc import ABC, abstractmethod
4
+ from pydantic import BaseModel, ConfigDict
5
+ from pydantic_ai import Agent, models
6
+ from pydantic_ai.exceptions import UnexpectedModelBehavior
7
+ from agentensor.tensor import TextTensor
8
+
9
+
10
+ class AgentModule(BaseModel, ABC):
11
+ """Agent module."""
12
+
13
+ model_config = ConfigDict(arbitrary_types_allowed=True)
14
+
15
+ system_prompt: TextTensor
16
+ model: models.Model | models.KnownModelName | str = "openai:gpt-4o"
17
+
18
+ def get_params(self) -> list[TextTensor]:
19
+ """Get the parameters of the module."""
20
+ params = []
21
+ for field_name in self.__class__.model_fields.keys():
22
+ field = getattr(self, field_name)
23
+ if isinstance(field, TextTensor) and field.requires_grad:
24
+ params.append(field)
25
+ return params
26
+
27
+ async def __call__(self, state: dict) -> dict:
28
+ """Run the agent node."""
29
+ assert state["output"]
30
+ agent = self.get_agent()
31
+ try:
32
+ result = await agent.run(state["output"].text)
33
+ output = str(result.output)
34
+ except UnexpectedModelBehavior: # pragma: no cover
35
+ output = "Error"
36
+
37
+ output_tensor = TextTensor(
38
+ output,
39
+ parents=[state["output"], self.system_prompt],
40
+ requires_grad=True,
41
+ model=self.model,
42
+ )
43
+
44
+ return {"output": output_tensor}
45
+
46
+ @abstractmethod
47
+ def get_agent(self) -> Agent:
48
+ """Get agent instance."""
49
+ pass # pragma: no cover
@@ -1,7 +1,7 @@
1
1
  """Optimizer module."""
2
2
 
3
- from pydantic_ai import Agent
4
- from pydantic_graph import Graph
3
+ from langgraph.graph import StateGraph
4
+ from pydantic_ai import Agent, models
5
5
  from agentensor.module import AgentModule
6
6
  from agentensor.tensor import TextTensor
7
7
 
@@ -9,16 +9,20 @@ from agentensor.tensor import TextTensor
9
9
  class Optimizer:
10
10
  """Optimizer class."""
11
11
 
12
- def __init__(self, graph: Graph) -> None:
12
+ def __init__(
13
+ self,
14
+ graph: StateGraph,
15
+ model: models.Model | models.KnownModelName | str | None = None,
16
+ ) -> None:
13
17
  """Initialize the optimizer."""
14
18
  self.params: list[TextTensor] = [
15
19
  param
16
- for node in graph.get_nodes()
17
- for param in node.get_params() # type: ignore[attr-defined]
18
- if issubclass(node, AgentModule)
20
+ for node in graph.nodes.values()
21
+ if isinstance(node.runnable.afunc, AgentModule) # type: ignore[attr-defined]
22
+ for param in node.runnable.afunc.get_params() # type: ignore[attr-defined]
19
23
  ]
20
24
  self.agent: Agent = Agent(
21
- model="openai:gpt-4o-mini",
25
+ model=model or "openai:gpt-4o-mini",
22
26
  system_prompt="Rewrite the system prompt given the feedback.",
23
27
  )
24
28
 
@@ -32,7 +36,7 @@ class Optimizer:
32
36
  def zero_grad(self) -> None:
33
37
  """Zero the gradients."""
34
38
  for param in self.params:
35
- param.text_grad = ""
39
+ param.zero_grad()
36
40
 
37
41
  def optimize(self, text: str, grad: str) -> str:
38
42
  """Optimize the text."""
@@ -1,7 +1,7 @@
1
1
  """Example module."""
2
2
 
3
3
  from __future__ import annotations
4
- from pydantic_ai import Agent
4
+ from pydantic_ai import Agent, models
5
5
 
6
6
 
7
7
  class TextTensor:
@@ -12,13 +12,15 @@ class TextTensor:
12
12
  text: str,
13
13
  parents: list[TextTensor] | None = None,
14
14
  requires_grad: bool = False,
15
+ model: models.Model | models.KnownModelName | str | None = None,
15
16
  ) -> None:
16
17
  """Initialize a TextTensor."""
17
18
  self.text = text
18
19
  self.requires_grad = requires_grad
19
- self.text_grad = ""
20
+ self.gradients: list[str] = []
20
21
  self.agent = Agent(
21
- model="openai:gpt-4o-mini", system_prompt="Answer the user's question."
22
+ model=model or "openai:gpt-4o-mini",
23
+ system_prompt="Answer the user's question.",
22
24
  )
23
25
  self.parents: list[TextTensor] = parents or []
24
26
 
@@ -32,7 +34,7 @@ class TextTensor:
32
34
  return
33
35
 
34
36
  if self.requires_grad:
35
- self.text_grad = grad
37
+ self.gradients.append(grad)
36
38
  for parent in self.parents:
37
39
  if not parent.requires_grad:
38
40
  continue
@@ -46,7 +48,16 @@ class TextTensor:
46
48
  f"output: \n\n>{output_text}\n\nHere is the feedback: \n\n"
47
49
  f">{grad}\n\nHow should I improve the input to get a "
48
50
  f"better output?"
49
- ).data
51
+ ).output
52
+
53
+ @property
54
+ def text_grad(self) -> str:
55
+ """String representation of the gradients."""
56
+ return " ".join(self.gradients)
57
+
58
+ def zero_grad(self) -> None:
59
+ """Zero the gradients."""
60
+ self.gradients = []
50
61
 
51
62
  def __str__(self) -> str:
52
63
  """Return the text as a string."""
@@ -0,0 +1,93 @@
1
+ """Trainer."""
2
+
3
+ from typing import Any, Literal
4
+ from langgraph.graph.graph import CompiledGraph
5
+ from pydantic_evals import Dataset
6
+ from pydantic_evals.reporting import EvaluationReport
7
+ from agentensor.optim import Optimizer
8
+ from agentensor.tensor import TextTensor
9
+
10
+
11
+ class Trainer:
12
+ """Trainer."""
13
+
14
+ def __init__(
15
+ self,
16
+ graph: CompiledGraph,
17
+ graph_recursion_limit: int = 25,
18
+ train_dataset: Dataset[TextTensor, TextTensor, Any] | None = None,
19
+ eval_dataset: Dataset[TextTensor, TextTensor, Any] | None = None,
20
+ test_dataset: Dataset[TextTensor, TextTensor, Any] | None = None,
21
+ optimizer: Optimizer | None = None,
22
+ epochs: int = 10,
23
+ stop_threshold: float = 0.95,
24
+ ):
25
+ """Initialize the trainer."""
26
+ self.graph = graph
27
+ self.graph_recursion_limit = graph_recursion_limit
28
+ self.optimizer = optimizer
29
+ self.epochs = epochs
30
+ self.stop_threshold = stop_threshold
31
+ self.train_dataset = train_dataset
32
+ self.eval_dataset = eval_dataset
33
+ self.test_dataset = test_dataset
34
+
35
+ async def forward(self, x: TextTensor) -> TextTensor:
36
+ """Forward the graph."""
37
+ result = await self.graph.ainvoke(
38
+ {"output": x}, {"recursion_limit": self.graph_recursion_limit}
39
+ )
40
+ return result["output"]
41
+
42
+ def train(self) -> None:
43
+ """Train the graph."""
44
+ assert self.train_dataset, "Train dataset is required"
45
+ assert self.optimizer, "Optimizer is required"
46
+ for i in range(self.epochs):
47
+ report = self.evaluate("train")
48
+ report.print(
49
+ include_input=True, include_output=True, include_durations=True
50
+ )
51
+
52
+ # Backward those failed cases
53
+ for case in report.cases:
54
+ losses = []
55
+ for evaluator in case.assertions.values():
56
+ if not evaluator.value:
57
+ assert evaluator.reason
58
+ losses.append(evaluator.reason)
59
+ if losses:
60
+ case.output.backward(" ".join(losses))
61
+
62
+ self.optimizer.step()
63
+ self.optimizer.zero_grad()
64
+
65
+ print(f"Epoch {i + 1}")
66
+ for param in self.optimizer.params:
67
+ print(param.text) # pragma: no cover
68
+ print()
69
+ performance = report.averages().assertions
70
+ assert performance is not None
71
+ if performance >= self.stop_threshold:
72
+ print("Optimization complete.")
73
+ break
74
+
75
+ def evaluate(
76
+ self,
77
+ data_split: Literal["train", "eval", "test"] = "eval",
78
+ limit_cases: int | None = None,
79
+ ) -> EvaluationReport:
80
+ """Evaluate the graph."""
81
+ dataset = getattr(self, f"{data_split}_dataset")
82
+ assert dataset, f"{data_split} dataset is required"
83
+ if limit_cases: # pragma: no cover
84
+ limited_cases = dataset.cases[:limit_cases]
85
+ dataset = Dataset(cases=limited_cases, evaluators=dataset.evaluators)
86
+ report = dataset.evaluate_sync(self.forward)
87
+
88
+ return report
89
+
90
+ def test(self, limit_cases: int | None = None) -> None:
91
+ """Test the graph."""
92
+ report = self.evaluate("test", limit_cases=limit_cases)
93
+ report.print(include_input=True, include_output=True, include_durations=True)
@@ -0,0 +1,151 @@
1
+ """Tasks."""
2
+
3
+ from __future__ import annotations
4
+ import json
5
+ from dataclasses import dataclass
6
+ from typing import TypedDict
7
+ from datasets import load_dataset
8
+ from langgraph.graph import END, START, StateGraph
9
+ from pydantic import BaseModel
10
+ from pydantic_ai import Agent, models
11
+ from pydantic_ai.models.openai import OpenAIModel
12
+ from pydantic_ai.providers.openai import OpenAIProvider
13
+ from pydantic_evals import Case, Dataset
14
+ from pydantic_evals.evaluators import EvaluationReason, Evaluator, EvaluatorContext
15
+ from agentensor.module import AgentModule
16
+ from agentensor.tensor import TextTensor
17
+ from agentensor.train import Trainer
18
+
19
+
20
+ @dataclass
21
+ class GenerationTimeout(Evaluator[str, bool]):
22
+ """The generation took too long."""
23
+
24
+ threshold: float = 10.0
25
+
26
+ async def evaluate(self, ctx: EvaluatorContext[str, bool]) -> EvaluationReason:
27
+ """Evaluate the time taken to generate the output."""
28
+ return EvaluationReason(
29
+ value=ctx.duration <= self.threshold,
30
+ reason=(
31
+ f"The generation took {ctx.duration} seconds, which is longer "
32
+ f"than the threshold of {self.threshold} seconds."
33
+ ),
34
+ )
35
+
36
+
37
+ @dataclass
38
+ class MultiLabelClassificationAccuracy(Evaluator):
39
+ """Classification accuracy evaluator."""
40
+
41
+ async def evaluate(self, ctx: EvaluatorContext) -> bool:
42
+ """Evaluate the accuracy of the classification."""
43
+ try:
44
+ output = json.loads(ctx.output.text)
45
+ except json.JSONDecodeError:
46
+ return False
47
+ expected = ctx.expected_output
48
+ return set(output) == set(expected) # type: ignore[arg-type]
49
+
50
+
51
+ class EvaluateState(TypedDict):
52
+ """State of the graph."""
53
+
54
+ output: TextTensor
55
+
56
+
57
+ class ClassificationResults(BaseModel, use_attribute_docstrings=True):
58
+ """Classification result for a data."""
59
+
60
+ labels: list[str]
61
+ """labels for this data point."""
62
+
63
+ def __str__(self) -> str:
64
+ """Return the string representation of the classification results."""
65
+ return json.dumps(self.labels)
66
+
67
+
68
+ class HFMultiClassClassificationTask:
69
+ """Multi-class classification task from Hugging Face."""
70
+
71
+ def __init__(
72
+ self,
73
+ task_repo: str,
74
+ evaluators: list[Evaluator],
75
+ model: models.Model | models.KnownModelName | str | None = None,
76
+ ) -> None:
77
+ """Initialize the multi-class classification task."""
78
+ self.task_repo = task_repo
79
+ self.evaluators = evaluators
80
+ self.model = model
81
+ self.dataset = self._prepare_dataset()
82
+
83
+ def _prepare_dataset(self) -> dict[str, Dataset]:
84
+ """Return the Pydantic Evals dataset."""
85
+ hf_dataset = load_dataset(self.task_repo, trust_remote_code=True)
86
+ dataset = {}
87
+ for split in hf_dataset.keys():
88
+ cases = []
89
+ for example in hf_dataset[split]:
90
+ cases.append(
91
+ Case(
92
+ inputs=TextTensor(
93
+ f"Title: {example['title']}\nContent: {example['content']}",
94
+ model=self.model,
95
+ ),
96
+ expected_output=example["all_labels"],
97
+ )
98
+ )
99
+ dataset[split] = Dataset(cases=cases, evaluators=self.evaluators)
100
+ return dataset
101
+
102
+
103
+ class AgentNode(AgentModule):
104
+ """Agent node."""
105
+
106
+ def get_agent(self) -> Agent:
107
+ """Get agent instance."""
108
+ return Agent(
109
+ model=self.model or "openai:gpt-4o-mini",
110
+ system_prompt=self.system_prompt.text,
111
+ output_type=ClassificationResults, # type: ignore[arg-type]
112
+ )
113
+
114
+
115
+ if __name__ == "__main__":
116
+ model = OpenAIModel(
117
+ model_name="llama3.2:1b",
118
+ provider=OpenAIProvider(base_url="http://localhost:11434/v1", api_key="ollama"),
119
+ )
120
+ # model = "openai:gpt-4o-mini"
121
+
122
+ task = HFMultiClassClassificationTask(
123
+ task_repo="knowledgator/events_classification_biotech",
124
+ evaluators=[GenerationTimeout(), MultiLabelClassificationAccuracy()],
125
+ model=model,
126
+ )
127
+ graph = StateGraph(EvaluateState)
128
+ graph.add_node(
129
+ "agent",
130
+ AgentNode(
131
+ system_prompt=TextTensor(
132
+ (
133
+ "Classify the following text into one of the following "
134
+ "categories: [expanding industry, new initiatives or programs, "
135
+ "article publication, other]"
136
+ ),
137
+ requires_grad=True,
138
+ model=model,
139
+ ),
140
+ model=model,
141
+ ),
142
+ )
143
+ graph.add_edge(START, "agent")
144
+ graph.add_edge("agent", END)
145
+ compiled_graph = graph.compile()
146
+ trainer = Trainer(
147
+ compiled_graph,
148
+ train_dataset=task.dataset["train"],
149
+ test_dataset=task.dataset["test"],
150
+ )
151
+ trainer.test(limit_cases=10)
@@ -0,0 +1,111 @@
1
+ """Example usage of agentensor."""
2
+
3
+ from __future__ import annotations
4
+ import os
5
+ from dataclasses import dataclass
6
+ from typing import Any, TypedDict
7
+ from langgraph.graph import END, START, StateGraph
8
+ from pydantic_ai import Agent, models
9
+ from pydantic_ai.models.openai import OpenAIModel
10
+ from pydantic_ai.providers.openai import OpenAIProvider
11
+ from pydantic_evals import Case, Dataset
12
+ from agentensor.loss import LLMTensorJudge
13
+ from agentensor.module import AgentModule
14
+ from agentensor.optim import Optimizer
15
+ from agentensor.tensor import TextTensor
16
+ from agentensor.train import Trainer
17
+
18
+
19
+ @dataclass
20
+ class ChineseLanguageJudge(LLMTensorJudge):
21
+ """Chinese language judge."""
22
+
23
+ rubric: str = "The output should be in Chinese."
24
+ model: models.Model | models.KnownModelName = "openai:gpt-4o-mini"
25
+ include_input = True
26
+
27
+
28
+ @dataclass
29
+ class FormatJudge(LLMTensorJudge):
30
+ """Format judge."""
31
+
32
+ rubric: str = "The output should start by introducing itself."
33
+ model: models.Model | models.KnownModelName = "openai:gpt-4o-mini"
34
+ include_input = True
35
+
36
+
37
+ class TrainState(TypedDict):
38
+ """State of the graph."""
39
+
40
+ output: TextTensor
41
+
42
+
43
+ class AgentNode(AgentModule):
44
+ """Agent node."""
45
+
46
+ def get_agent(self) -> Agent:
47
+ """Get agent instance."""
48
+ return Agent(
49
+ model=self.model or "openai:gpt-4o-mini",
50
+ system_prompt=self.system_prompt.text,
51
+ )
52
+
53
+
54
+ def main() -> None:
55
+ """Main function."""
56
+ if os.environ.get("LOGFIRE_TOKEN", None):
57
+ import logfire
58
+
59
+ logfire.configure(
60
+ send_to_logfire="if-token-present",
61
+ environment="development",
62
+ service_name="evals",
63
+ )
64
+ model = OpenAIModel(
65
+ model_name="llama3.2:1b",
66
+ provider=OpenAIProvider(base_url="http://localhost:11434/v1", api_key="ollama"),
67
+ )
68
+ # model="openai:gpt-4o-mini"
69
+
70
+ dataset = Dataset[TextTensor, TextTensor, Any](
71
+ cases=[
72
+ Case(
73
+ inputs=TextTensor("Hello, how are you?", model=model),
74
+ metadata={"language": "English"},
75
+ ),
76
+ Case(
77
+ inputs=TextTensor("こんにちは、元気ですか?", model=model),
78
+ metadata={"language": "Japanese"},
79
+ ),
80
+ ],
81
+ evaluators=[
82
+ ChineseLanguageJudge(model=model),
83
+ FormatJudge(model=model),
84
+ ],
85
+ )
86
+
87
+ graph = StateGraph(TrainState)
88
+ graph.add_node(
89
+ "agent",
90
+ AgentNode(
91
+ system_prompt=TextTensor(
92
+ "You are a helpful assistant.", requires_grad=True, model=model
93
+ ),
94
+ model=model,
95
+ ),
96
+ )
97
+ graph.add_edge(START, "agent")
98
+ graph.add_edge("agent", END)
99
+ compiled_graph = graph.compile()
100
+ optimizer = Optimizer(graph, model=model)
101
+ trainer = Trainer(
102
+ compiled_graph,
103
+ train_dataset=dataset,
104
+ optimizer=optimizer,
105
+ epochs=15,
106
+ )
107
+ trainer.train()
108
+
109
+
110
+ if __name__ == "__main__":
111
+ main()
@@ -14,7 +14,7 @@ dev = [
14
14
  "pytest-cov>=4.1.0",
15
15
  "ruff>=0.11.3",
16
16
  "smokeshow>=0.5.0",
17
- "types-requests",
17
+ "types-requests"
18
18
  ]
19
19
  docs = [
20
20
  "mkdocs",
@@ -26,17 +26,18 @@ docs = [
26
26
 
27
27
  [project]
28
28
  dependencies = [
29
- "logfire>=3.14.0",
30
- "pydantic-ai>=0.0.55",
29
+ "datasets>=3.5.0",
30
+ "langgraph>=0.4.5",
31
+ "logfire>=3.14.0",
32
+ "pydantic-ai>=0.2.4"
31
33
  ]
32
34
  description = "Add your description here"
33
35
  license = {file = "LICENSE"}
34
36
  name = "agentensor"
35
37
  readme = "README.md"
36
38
  requires-python = ">=3.12"
37
- version = "0.0.1"
38
39
  url = "https://github.com/ShaojieJiang/agentensor"
39
-
40
+ version = "0.0.3"
40
41
 
41
42
  [tool.coverage.report]
42
43
  exclude_lines = [
@@ -1,8 +1,8 @@
1
1
  """Test module for the Module class."""
2
2
 
3
- from unittest.mock import MagicMock, patch
3
+ from unittest.mock import AsyncMock, MagicMock, patch
4
4
  import pytest
5
- from agentensor.module import AgentModule, ModuleState
5
+ from agentensor.module import AgentModule
6
6
  from agentensor.tensor import TextTensor
7
7
 
8
8
 
@@ -15,28 +15,17 @@ def mock_agent():
15
15
  yield mock_agent
16
16
 
17
17
 
18
- def test_module_state_initialization(mock_agent):
19
- """Test ModuleState initialization."""
20
- input_tensor = TextTensor("test input")
21
- state = ModuleState(input=input_tensor)
22
-
23
- assert isinstance(state.input, TextTensor)
24
- assert state.input.text == "test input"
25
-
26
-
27
- def test_module_get_params(mock_agent):
18
+ def test_module_get_params():
28
19
  """Test AgentModule.get_params() method."""
29
20
 
30
21
  class TestModule(AgentModule):
31
- param1 = TextTensor("param1", requires_grad=True)
32
- param2 = TextTensor("param2", requires_grad=False)
33
- param3 = TextTensor("param3", requires_grad=True)
34
- non_param = "not a tensor"
35
-
36
- def __init__(self):
37
- pass
22
+ system_prompt: TextTensor = TextTensor("param1", requires_grad=True)
23
+ param2: TextTensor = TextTensor("param2", requires_grad=False)
24
+ param3: TextTensor = TextTensor("param3", requires_grad=True)
25
+ model: str = "openai:gpt-4o"
26
+ non_param: str = "not a tensor"
38
27
 
39
- def run(self, state: ModuleState) -> None:
28
+ def get_agent(self):
40
29
  """Dummy run method for testing."""
41
30
  pass
42
31
 
@@ -54,13 +43,10 @@ def test_module_get_params_empty(mock_agent):
54
43
  """Test AgentModule.get_params() with no parameters."""
55
44
 
56
45
  class EmptyModule(AgentModule):
57
- non_param = "not a tensor"
58
- param = TextTensor("param", requires_grad=False)
46
+ system_prompt: TextTensor = TextTensor("param", requires_grad=False)
47
+ non_param: str = "not a tensor"
59
48
 
60
- def __init__(self):
61
- pass
62
-
63
- def run(self, state: ModuleState) -> None:
49
+ def get_agent(self):
64
50
  """Dummy run method for testing."""
65
51
  pass
66
52
 
@@ -70,26 +56,20 @@ def test_module_get_params_empty(mock_agent):
70
56
  assert len(params) == 0
71
57
 
72
58
 
73
- def test_module_get_params_inheritance(mock_agent):
59
+ def test_module_get_params_inheritance():
74
60
  """Test AgentModule.get_params() with inheritance."""
75
61
 
76
62
  class ParentModule(AgentModule):
77
- parent_param = TextTensor("parent", requires_grad=True)
78
-
79
- def __init__(self):
80
- pass
63
+ system_prompt: TextTensor = TextTensor("parent", requires_grad=True)
81
64
 
82
- def run(self, state: ModuleState) -> None:
65
+ def get_agent(self):
83
66
  """Dummy run method for testing."""
84
67
  pass
85
68
 
86
69
  class ChildModule(ParentModule):
87
- child_param = TextTensor("child", requires_grad=True)
70
+ child_param: TextTensor = TextTensor("child", requires_grad=True)
88
71
 
89
- def __init__(self):
90
- super().__init__()
91
-
92
- def run(self, state: ModuleState) -> None:
72
+ def get_agent(self):
93
73
  """Dummy run method for testing."""
94
74
  pass
95
75
 
@@ -100,3 +80,23 @@ def test_module_get_params_inheritance(mock_agent):
100
80
  assert all(isinstance(p, TextTensor) for p in params)
101
81
  assert all(p.requires_grad for p in params)
102
82
  assert {p.text for p in params} == {"parent", "child"}
83
+
84
+
85
+ @pytest.mark.asyncio
86
+ async def test_module_call():
87
+ class TestModule(AgentModule):
88
+ system_prompt: TextTensor = TextTensor("system prompt", requires_grad=True)
89
+
90
+ def get_agent(self):
91
+ """Dummy run method for testing."""
92
+ mock_agent = AsyncMock()
93
+ run_output = MagicMock()
94
+ run_output.output = "Output text"
95
+ mock_agent.run.return_value = run_output
96
+ return mock_agent
97
+
98
+ module = TestModule()
99
+
100
+ result = await module({"output": TextTensor("Input text")})
101
+ assert isinstance(result["output"], TextTensor)
102
+ assert result["output"].text == "Output text"