agentensor 0.0.2__tar.gz → 0.0.4__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (32)
  1. {agentensor-0.0.2 → agentensor-0.0.4}/.bumpversion.cfg +1 -1
  2. {agentensor-0.0.2 → agentensor-0.0.4}/PKG-INFO +6 -2
  3. {agentensor-0.0.2 → agentensor-0.0.4}/agentensor/loss.py +1 -1
  4. agentensor-0.0.4/agentensor/module.py +62 -0
  5. agentensor-0.0.4/agentensor/optim.py +57 -0
  6. {agentensor-0.0.2 → agentensor-0.0.4}/agentensor/tensor.py +32 -12
  7. {agentensor-0.0.2 → agentensor-0.0.4}/agentensor/train.py +8 -11
  8. {agentensor-0.0.2 → agentensor-0.0.4}/examples/evaluate.py +43 -50
  9. agentensor-0.0.4/examples/train.py +118 -0
  10. {agentensor-0.0.2 → agentensor-0.0.4}/pyproject.toml +6 -2
  11. {agentensor-0.0.2 → agentensor-0.0.4}/tests/test_loss.py +5 -1
  12. agentensor-0.0.4/tests/test_module.py +119 -0
  13. agentensor-0.0.4/tests/test_optim.py +130 -0
  14. agentensor-0.0.4/tests/test_tensor.py +156 -0
  15. {agentensor-0.0.2 → agentensor-0.0.4}/tests/test_train.py +54 -47
  16. {agentensor-0.0.2 → agentensor-0.0.4}/uv.lock +481 -46
  17. agentensor-0.0.2/agentensor/module.py +0 -26
  18. agentensor-0.0.2/agentensor/optim.py +0 -41
  19. agentensor-0.0.2/examples/train.py +0 -109
  20. agentensor-0.0.2/tests/test_module.py +0 -102
  21. agentensor-0.0.2/tests/test_optim.py +0 -105
  22. agentensor-0.0.2/tests/test_tensor.py +0 -133
  23. {agentensor-0.0.2 → agentensor-0.0.4}/.github/workflows/after-ci.yml +0 -0
  24. {agentensor-0.0.2 → agentensor-0.0.4}/.github/workflows/ci.yml +0 -0
  25. {agentensor-0.0.2 → agentensor-0.0.4}/.gitignore +0 -0
  26. {agentensor-0.0.2 → agentensor-0.0.4}/.pre-commit-config.yaml +0 -0
  27. {agentensor-0.0.2 → agentensor-0.0.4}/.python-version +0 -0
  28. {agentensor-0.0.2 → agentensor-0.0.4}/LICENSE +0 -0
  29. {agentensor-0.0.2 → agentensor-0.0.4}/Makefile +0 -0
  30. {agentensor-0.0.2 → agentensor-0.0.4}/README.md +0 -0
  31. {agentensor-0.0.2 → agentensor-0.0.4}/agentensor/__init__.py +0 -0
  32. {agentensor-0.0.2 → agentensor-0.0.4}/mkdocs.yml +0 -0
@@ -1,5 +1,5 @@
1
1
  [bumpversion]
2
- current_version = 0.0.2
2
+ current_version = 0.0.4
3
3
  commit = True
4
4
  tag = True
5
5
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: agentensor
3
- Version: 0.0.2
3
+ Version: 0.0.4
4
4
  Summary: Add your description here
5
5
  License: MIT License
6
6
 
@@ -26,8 +26,12 @@ License: MIT License
26
26
  License-File: LICENSE
27
27
  Requires-Python: >=3.12
28
28
  Requires-Dist: datasets>=3.5.0
29
+ Requires-Dist: langchain-ollama>=0.3.3
30
+ Requires-Dist: langchain-openai>=0.3.18
31
+ Requires-Dist: langchain>=0.3.25
32
+ Requires-Dist: langgraph>=0.4.5
29
33
  Requires-Dist: logfire>=3.14.0
30
- Requires-Dist: pydantic-ai>=0.1.3
34
+ Requires-Dist: pydantic-ai>=0.2.4
31
35
  Description-Content-Type: text/markdown
32
36
 
33
37
  # AgenTensor
@@ -16,7 +16,7 @@ class LLMTensorJudge(Evaluator[TextTensor, TextTensor, Any]):
16
16
  """
17
17
 
18
18
  rubric: str
19
- model: models.KnownModelName | None = None
19
+ model: models.Model | models.KnownModelName | None = None
20
20
  include_input: bool = True
21
21
 
22
22
  async def evaluate(
@@ -0,0 +1,62 @@
1
+ """Module class."""
2
+
3
+ from abc import ABC, abstractmethod
4
+ from typing import Any
5
+ from langchain.chat_models import init_chat_model
6
+ from langchain_core.language_models import BaseChatModel
7
+ from langchain_core.messages import HumanMessage
8
+ from langgraph.graph.graph import CompiledGraph
9
+ from pydantic import BaseModel, ConfigDict
10
+ from pydantic_ai.exceptions import UnexpectedModelBehavior
11
+ from agentensor.tensor import TextTensor
12
+
13
+
14
+ class AgentModule(BaseModel, ABC):
15
+ """Agent module."""
16
+
17
+ model_config = ConfigDict(arbitrary_types_allowed=True)
18
+
19
+ system_prompt: TextTensor
20
+ llm: str | BaseChatModel = "gpt-4o-mini"
21
+
22
+ def model_post_init(self, __context: Any) -> None:
23
+ """Post initialization hook."""
24
+ if isinstance(self.llm, str): # pragma: no cover
25
+ self.llm = init_chat_model(self.llm)
26
+
27
+ def get_params(self) -> list[TextTensor]:
28
+ """Get the parameters of the module."""
29
+ params = []
30
+ for field_name in self.__class__.model_fields.keys():
31
+ field = getattr(self, field_name)
32
+ if isinstance(field, TextTensor) and field.requires_grad:
33
+ params.append(field)
34
+ return params
35
+
36
+ async def __call__(self, state: dict) -> dict:
37
+ """Run the agent node."""
38
+ assert state["output"]
39
+ try:
40
+ result = await self.agent.ainvoke(
41
+ {"messages": [HumanMessage(content=state["output"].text)]}
42
+ )
43
+ output = str(
44
+ result.get("structured_response", result["messages"][-1].content)
45
+ ) # prioritize structured response over raw response
46
+ except UnexpectedModelBehavior: # pragma: no cover
47
+ output = "Error"
48
+
49
+ output_tensor = TextTensor(
50
+ output,
51
+ parents=[state["output"], self.system_prompt],
52
+ requires_grad=True,
53
+ model=self.llm,
54
+ )
55
+
56
+ return {"output": output_tensor}
57
+
58
+ @property
59
+ @abstractmethod
60
+ def agent(self) -> CompiledGraph:
61
+ """Get agent instance."""
62
+ pass # pragma: no cover
@@ -0,0 +1,57 @@
1
+ """Optimizer module."""
2
+
3
+ from langchain.chat_models import init_chat_model
4
+ from langchain_core.language_models import BaseChatModel
5
+ from langchain_core.messages import HumanMessage
6
+ from langgraph.graph import StateGraph
7
+ from langgraph.graph.graph import CompiledGraph
8
+ from langgraph.prebuilt import create_react_agent
9
+ from agentensor.module import AgentModule
10
+ from agentensor.tensor import TextTensor
11
+
12
+
13
+ class Optimizer:
14
+ """Optimizer class."""
15
+
16
+ def __init__(
17
+ self,
18
+ graph: StateGraph,
19
+ model: str | BaseChatModel = "gpt-4o-mini",
20
+ ) -> None:
21
+ """Initialize the optimizer."""
22
+ self.params: list[TextTensor] = [
23
+ param
24
+ for node in graph.nodes.values()
25
+ if isinstance(node.runnable.afunc, AgentModule) # type: ignore[attr-defined]
26
+ for param in node.runnable.afunc.get_params() # type: ignore[attr-defined]
27
+ ]
28
+ if isinstance(model, str):
29
+ self.model = init_chat_model(model)
30
+ else: # pragma: no cover
31
+ self.model = model
32
+
33
+ def step(self) -> None:
34
+ """Step the optimizer."""
35
+ for param in self.params:
36
+ if not param.text_grad:
37
+ continue
38
+ param.text = self.optimize(param.text, param.text_grad)
39
+
40
+ def zero_grad(self) -> None:
41
+ """Zero the gradients."""
42
+ for param in self.params:
43
+ param.zero_grad()
44
+
45
+ def optimize(self, text: str, grad: str) -> str:
46
+ """Optimize the text."""
47
+ result = self.agent.invoke(
48
+ {"messages": [HumanMessage(content=f"Feedback: {grad}\nText: {text}")]}
49
+ )
50
+ return result["messages"][-1].content
51
+
52
+ @property
53
+ def agent(self) -> CompiledGraph:
54
+ """Get the agent."""
55
+ return create_react_agent(
56
+ self.model, tools=[], prompt="Rewrite the system prompt given the feedback."
57
+ )
@@ -1,7 +1,11 @@
1
1
  """Example module."""
2
2
 
3
3
  from __future__ import annotations
4
- from pydantic_ai import Agent, models
4
+ from langchain.chat_models import init_chat_model
5
+ from langchain_core.language_models import BaseChatModel
6
+ from langchain_core.messages import HumanMessage
7
+ from langgraph.graph.graph import CompiledGraph
8
+ from langgraph.prebuilt import create_react_agent
5
9
 
6
10
 
7
11
  class TextTensor:
@@ -12,17 +16,17 @@ class TextTensor:
12
16
  text: str,
13
17
  parents: list[TextTensor] | None = None,
14
18
  requires_grad: bool = False,
15
- model: models.Model | models.KnownModelName | str | None = None,
19
+ model: str | BaseChatModel = "gpt-4o-mini",
16
20
  ) -> None:
17
21
  """Initialize a TextTensor."""
18
22
  self.text = text
19
23
  self.requires_grad = requires_grad
20
24
  self.gradients: list[str] = []
21
- self.agent = Agent(
22
- model=model or "openai:gpt-4o-mini",
23
- system_prompt="Answer the user's question.",
24
- )
25
25
  self.parents: list[TextTensor] = parents or []
26
+ if isinstance(model, str):
27
+ self.model = init_chat_model(model)
28
+ else:
29
+ self.model = model
26
30
 
27
31
  def backward(self, grad: str = "") -> None:
28
32
  """Backward pass for the TextTensor.
@@ -43,12 +47,21 @@ class TextTensor:
43
47
 
44
48
  def calc_grad(self, input_text: str, output_text: str, grad: str) -> str:
45
49
  """Calculate the gradient for the TextTensor."""
46
- return self.agent.run_sync(
47
- f"Here is the input: \n\n>{input_text}\n\nI got this "
48
- f"output: \n\n>{output_text}\n\nHere is the feedback: \n\n"
49
- f">{grad}\n\nHow should I improve the input to get a "
50
- f"better output?"
51
- ).data
50
+ result = self.agent.invoke(
51
+ {
52
+ "messages": [
53
+ HumanMessage(
54
+ content=(
55
+ f"Here is the input: \n\n>{input_text}\n\nI got this "
56
+ f"output: \n\n>{output_text}\n\nHere is the feedback: \n\n"
57
+ f">{grad}\n\nHow should I improve the input to get a "
58
+ f"better output?"
59
+ )
60
+ )
61
+ ]
62
+ }
63
+ )
64
+ return result["messages"][-1].content
52
65
 
53
66
  @property
54
67
  def text_grad(self) -> str:
@@ -62,3 +75,10 @@ class TextTensor:
62
75
  def __str__(self) -> str:
63
76
  """Return the text as a string."""
64
77
  return self.text
78
+
79
+ @property
80
+ def agent(self) -> CompiledGraph:
81
+ """Get the agent."""
82
+ return create_react_agent(
83
+ self.model, tools=[], prompt="Answer the user's question."
84
+ )
@@ -1,11 +1,9 @@
1
1
  """Trainer."""
2
2
 
3
3
  from typing import Any, Literal
4
+ from langgraph.graph.graph import CompiledGraph
4
5
  from pydantic_evals import Dataset
5
6
  from pydantic_evals.reporting import EvaluationReport
6
- from pydantic_graph import Graph
7
- from pydantic_graph.nodes import DepsT, StateT
8
- from agentensor.module import AgentModule, ModuleState
9
7
  from agentensor.optim import Optimizer
10
8
  from agentensor.tensor import TextTensor
11
9
 
@@ -15,9 +13,8 @@ class Trainer:
15
13
 
16
14
  def __init__(
17
15
  self,
18
- graph: Graph[StateT, DepsT, TextTensor],
19
- graph_state: ModuleState,
20
- start_node: type[AgentModule],
16
+ graph: CompiledGraph,
17
+ graph_recursion_limit: int = 25,
21
18
  train_dataset: Dataset[TextTensor, TextTensor, Any] | None = None,
22
19
  eval_dataset: Dataset[TextTensor, TextTensor, Any] | None = None,
23
20
  test_dataset: Dataset[TextTensor, TextTensor, Any] | None = None,
@@ -27,8 +24,7 @@ class Trainer:
27
24
  ):
28
25
  """Initialize the trainer."""
29
26
  self.graph = graph
30
- self.graph_state = graph_state
31
- self.start_node = start_node
27
+ self.graph_recursion_limit = graph_recursion_limit
32
28
  self.optimizer = optimizer
33
29
  self.epochs = epochs
34
30
  self.stop_threshold = stop_threshold
@@ -38,9 +34,10 @@ class Trainer:
38
34
 
39
35
  async def forward(self, x: TextTensor) -> TextTensor:
40
36
  """Forward the graph."""
41
- self.graph_state.input = x
42
- result = await self.graph.run(self.start_node(), state=self.graph_state) # type: ignore[arg-type]
43
- return result.output
37
+ result = await self.graph.ainvoke(
38
+ {"output": x}, {"recursion_limit": self.graph_recursion_limit}
39
+ )
40
+ return result["output"]
44
41
 
45
42
  def train(self) -> None:
46
43
  """Train the graph."""
@@ -3,16 +3,17 @@
3
3
  from __future__ import annotations
4
4
  import json
5
5
  from dataclasses import dataclass
6
+ from typing import TypedDict
6
7
  from datasets import load_dataset
8
+ from langchain.chat_models import init_chat_model
9
+ from langchain_core.language_models import BaseChatModel
10
+ from langgraph.graph import END, START, StateGraph
11
+ from langgraph.graph.graph import CompiledGraph
12
+ from langgraph.prebuilt import create_react_agent
7
13
  from pydantic import BaseModel
8
- from pydantic_ai import Agent, models
9
- from pydantic_ai.exceptions import UnexpectedModelBehavior
10
- from pydantic_ai.models.openai import OpenAIModel
11
- from pydantic_ai.providers.openai import OpenAIProvider
12
14
  from pydantic_evals import Case, Dataset
13
15
  from pydantic_evals.evaluators import EvaluationReason, Evaluator, EvaluatorContext
14
- from pydantic_graph import End, Graph, GraphRunContext
15
- from agentensor.module import AgentModule, ModuleState
16
+ from agentensor.module import AgentModule
16
17
  from agentensor.tensor import TextTensor
17
18
  from agentensor.train import Trainer
18
19
 
@@ -48,11 +49,10 @@ class MultiLabelClassificationAccuracy(Evaluator):
48
49
  return set(output) == set(expected) # type: ignore[arg-type]
49
50
 
50
51
 
51
- @dataclass
52
- class EvaluateState(ModuleState):
52
+ class EvaluateState(TypedDict):
53
53
  """State of the graph."""
54
54
 
55
- agent_prompt: TextTensor = TextTensor(text="")
55
+ output: TextTensor
56
56
 
57
57
 
58
58
  class ClassificationResults(BaseModel, use_attribute_docstrings=True):
@@ -73,12 +73,15 @@ class HFMultiClassClassificationTask:
73
73
  self,
74
74
  task_repo: str,
75
75
  evaluators: list[Evaluator],
76
- model: models.Model | models.KnownModelName | str | None = None,
76
+ model: BaseChatModel | str = "gpt-4o-mini",
77
77
  ) -> None:
78
78
  """Initialize the multi-class classification task."""
79
79
  self.task_repo = task_repo
80
80
  self.evaluators = evaluators
81
- self.model = model
81
+ if isinstance(model, str):
82
+ self.model = init_chat_model(model)
83
+ else:
84
+ self.model = model
82
85
  self.dataset = self._prepare_dataset()
83
86
 
84
87
  def _prepare_dataset(self) -> dict[str, Dataset]:
@@ -101,60 +104,50 @@ class HFMultiClassClassificationTask:
101
104
  return dataset
102
105
 
103
106
 
104
- class AgentNode(AgentModule[EvaluateState, None, TextTensor]):
107
+ class AgentNode(AgentModule):
105
108
  """Agent node."""
106
109
 
107
- async def run(self, ctx: GraphRunContext[EvaluateState, None]) -> End[TextTensor]: # type: ignore[override]
108
- """Run the agent node."""
109
- agent = Agent(
110
- model=model,
111
- system_prompt=ctx.state.agent_prompt.text,
112
- output_type=ClassificationResults,
113
- )
114
- assert ctx.state.input
115
- try:
116
- result = await agent.run(ctx.state.input.text)
117
- output = result.output
118
- except UnexpectedModelBehavior:
119
- output = "Error" # type: ignore[assignment]
120
-
121
- output_tensor = TextTensor(
122
- str(output),
123
- parents=[ctx.state.input, ctx.state.agent_prompt],
124
- requires_grad=True,
110
+ @property
111
+ def agent(self) -> CompiledGraph:
112
+ """Get agent instance."""
113
+ return create_react_agent(
114
+ self.llm,
115
+ tools=[],
116
+ prompt=self.system_prompt.text,
117
+ response_format=ClassificationResults,
125
118
  )
126
119
 
127
- return End(output_tensor)
128
-
129
120
 
130
121
  if __name__ == "__main__":
131
- model = OpenAIModel(
132
- model_name="llama3.2:1b",
133
- provider=OpenAIProvider(base_url="http://localhost:11434/v1", api_key="ollama"),
134
- )
135
- # model = "openai:gpt-4o-mini"
122
+ model = init_chat_model("llama3.2:1b", model_provider="ollama")
123
+ # model = "gpt-4o-mini"
136
124
 
137
125
  task = HFMultiClassClassificationTask(
138
126
  task_repo="knowledgator/events_classification_biotech",
139
127
  evaluators=[GenerationTimeout(), MultiLabelClassificationAccuracy()],
140
128
  model=model,
141
129
  )
142
- state = EvaluateState(
143
- agent_prompt=TextTensor(
144
- (
145
- "Classify the following text into one of the following "
146
- "categories: [expanding industry, new initiatives or programs, "
147
- "article publication, other]"
130
+ graph = StateGraph(EvaluateState)
131
+ graph.add_node(
132
+ "agent",
133
+ AgentNode(
134
+ system_prompt=TextTensor(
135
+ (
136
+ "Classify the following text into one of the following "
137
+ "categories: [expanding industry, new initiatives or programs, "
138
+ "article publication, other]"
139
+ ),
140
+ requires_grad=True,
141
+ model=model,
148
142
  ),
149
- requires_grad=True,
150
- model=model,
151
- )
143
+ llm=model,
144
+ ),
152
145
  )
153
- graph = Graph(nodes=[AgentNode])
146
+ graph.add_edge(START, "agent")
147
+ graph.add_edge("agent", END)
148
+ compiled_graph = graph.compile()
154
149
  trainer = Trainer(
155
- graph,
156
- state,
157
- AgentNode, # type: ignore[arg-type]
150
+ compiled_graph,
158
151
  train_dataset=task.dataset["train"],
159
152
  test_dataset=task.dataset["test"],
160
153
  )
@@ -0,0 +1,118 @@
1
+ """Example usage of agentensor."""
2
+
3
+ from __future__ import annotations
4
+ import os
5
+ from dataclasses import dataclass
6
+ from typing import Any, TypedDict
7
+ from langchain.chat_models import init_chat_model
8
+ from langgraph.graph import END, START, StateGraph
9
+ from langgraph.graph.graph import CompiledGraph
10
+ from langgraph.prebuilt import create_react_agent
11
+ from pydantic_ai import models
12
+ from pydantic_ai.models.openai import OpenAIModel
13
+ from pydantic_ai.providers.openai import OpenAIProvider
14
+ from pydantic_evals import Case, Dataset
15
+ from agentensor.loss import LLMTensorJudge
16
+ from agentensor.module import AgentModule
17
+ from agentensor.optim import Optimizer
18
+ from agentensor.tensor import TextTensor
19
+ from agentensor.train import Trainer
20
+
21
+
22
+ @dataclass
23
+ class ChineseLanguageJudge(LLMTensorJudge):
24
+ """Chinese language judge."""
25
+
26
+ rubric: str = "The output should be in Chinese."
27
+ model: models.Model | models.KnownModelName = "openai:gpt-4o-mini"
28
+ include_input = True
29
+
30
+
31
+ @dataclass
32
+ class FormatJudge(LLMTensorJudge):
33
+ """Format judge."""
34
+
35
+ rubric: str = "The output should start by introducing itself."
36
+ model: models.Model | models.KnownModelName = "openai:gpt-4o-mini"
37
+ include_input = True
38
+
39
+
40
+ class TrainState(TypedDict):
41
+ """State of the graph."""
42
+
43
+ output: TextTensor
44
+
45
+
46
+ class AgentNode(AgentModule):
47
+ """Agent node."""
48
+
49
+ @property
50
+ def agent(self) -> CompiledGraph:
51
+ """Get agent instance."""
52
+ return create_react_agent(
53
+ self.llm,
54
+ tools=[],
55
+ prompt=self.system_prompt.text,
56
+ )
57
+
58
+
59
+ def main() -> None:
60
+ """Main function."""
61
+ if os.environ.get("LOGFIRE_TOKEN", None):
62
+ import logfire
63
+
64
+ logfire.configure(
65
+ send_to_logfire="if-token-present",
66
+ environment="development",
67
+ service_name="evals",
68
+ )
69
+ eval_model = OpenAIModel(
70
+ model_name="llama3.2:1b",
71
+ provider=OpenAIProvider(base_url="http://localhost:11434/v1", api_key="ollama"),
72
+ )
73
+ model = init_chat_model("llama3.2:1b", model_provider="ollama")
74
+ # eval_model = "gpt-4o-mini"
75
+ # model = "gpt-4o-mini"
76
+
77
+ dataset = Dataset[TextTensor, TextTensor, Any](
78
+ cases=[
79
+ Case(
80
+ inputs=TextTensor("Hello, how are you?", model=model),
81
+ metadata={"language": "English"},
82
+ ),
83
+ Case(
84
+ inputs=TextTensor("こんにちは、元気ですか?", model=model),
85
+ metadata={"language": "Japanese"},
86
+ ),
87
+ ],
88
+ evaluators=[
89
+ ChineseLanguageJudge(model=eval_model),
90
+ FormatJudge(model=eval_model),
91
+ ],
92
+ )
93
+
94
+ graph = StateGraph(TrainState)
95
+ graph.add_node(
96
+ "agent",
97
+ AgentNode(
98
+ system_prompt=TextTensor(
99
+ "You are a helpful assistant.", requires_grad=True, model=model
100
+ ),
101
+ llm=model,
102
+ ),
103
+ )
104
+ graph.add_edge(START, "agent")
105
+ graph.add_edge("agent", END)
106
+ compiled_graph = graph.compile()
107
+ optimizer = Optimizer(graph, model=model)
108
+ trainer = Trainer(
109
+ compiled_graph,
110
+ train_dataset=dataset,
111
+ optimizer=optimizer,
112
+ epochs=15,
113
+ )
114
+ trainer.train()
115
+
116
+
117
+ if __name__ == "__main__":
118
+ main()
@@ -27,8 +27,12 @@ docs = [
27
27
  [project]
28
28
  dependencies = [
29
29
  "datasets>=3.5.0",
30
+ "langchain>=0.3.25",
31
+ "langchain-ollama>=0.3.3",
32
+ "langchain-openai>=0.3.18",
33
+ "langgraph>=0.4.5",
30
34
  "logfire>=3.14.0",
31
- "pydantic-ai>=0.1.3"
35
+ "pydantic-ai>=0.2.4"
32
36
  ]
33
37
  description = "Add your description here"
34
38
  license = {file = "LICENSE"}
@@ -36,7 +40,7 @@ name = "agentensor"
36
40
  readme = "README.md"
37
41
  requires-python = ">=3.12"
38
42
  url = "https://github.com/ShaojieJiang/agentensor"
39
- version = "0.0.2"
43
+ version = "0.0.4"
40
44
 
41
45
  [tool.coverage.report]
42
46
  exclude_lines = [
@@ -33,8 +33,12 @@ def mock_judge_output():
33
33
 
34
34
 
35
35
  @pytest.fixture
36
- def evaluator_context(mock_openai):
36
+ @patch("agentensor.tensor.init_chat_model")
37
+ def evaluator_context(mock_init_chat_model, mock_openai):
37
38
  """Create a test evaluator context."""
39
+ # Mock the model initialization
40
+ mock_init_chat_model.return_value = MagicMock()
41
+
38
42
  return EvaluatorContext(
39
43
  name="test_evaluator",
40
44
  inputs=TextTensor(text="test input"),