goose-py 0.6.0__tar.gz → 0.7.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38)
  1. goose_py-0.7.1/.envrc +1 -0
  2. goose_py-0.7.1/.github/workflows/publish.yml +48 -0
  3. goose_py-0.7.1/.gitignore +5 -0
  4. goose_py-0.7.1/.python-version +1 -0
  5. goose_py-0.7.1/.stubs/jsonpath_ng/__init__.pyi +14 -0
  6. goose_py-0.7.1/.stubs/litellm/__init__.pyi +61 -0
  7. goose_py-0.7.1/Makefile +9 -0
  8. goose_py-0.7.1/PKG-INFO +14 -0
  9. goose_py-0.7.1/goose/__init__.py +4 -0
  10. goose_py-0.6.0/goose/_agent.py → goose_py-0.7.1/goose/_internal/agent.py +14 -27
  11. goose_py-0.6.0/goose/_conversation.py → goose_py-0.7.1/goose/_internal/conversation.py +9 -12
  12. goose_py-0.6.0/goose/_flow.py → goose_py-0.7.1/goose/_internal/flow.py +8 -9
  13. goose_py-0.6.0/goose/_state.py → goose_py-0.7.1/goose/_internal/state.py +9 -19
  14. goose_py-0.6.0/goose/_store.py → goose_py-0.7.1/goose/_internal/store.py +2 -2
  15. goose_py-0.6.0/goose/_task.py → goose_py-0.7.1/goose/_internal/task.py +11 -18
  16. goose_py-0.7.1/goose/agent.py +28 -0
  17. goose_py-0.7.1/goose/flow.py +3 -0
  18. goose_py-0.7.1/goose/runs.py +4 -0
  19. goose_py-0.7.1/pyproject.toml +81 -0
  20. goose_py-0.7.1/tests/__init__.py +0 -0
  21. goose_py-0.7.1/tests/conftest.py +10 -0
  22. goose_py-0.7.1/tests/test_agent.py +66 -0
  23. goose_py-0.7.1/tests/test_complex_flow_arguments.py +22 -0
  24. goose_py-0.7.1/tests/test_downstream_task.py +35 -0
  25. goose_py-0.7.1/tests/test_jamming.py +76 -0
  26. goose_py-0.7.1/tests/test_looping.py +63 -0
  27. goose_py-0.7.1/tests/test_regenerate.py +48 -0
  28. goose_py-0.7.1/tests/test_state.py +49 -0
  29. goose_py-0.7.1/uv.lock +1057 -0
  30. goose_py-0.6.0/PKG-INFO +0 -31
  31. goose_py-0.6.0/goose/__init__.py +0 -5
  32. goose_py-0.6.0/pyproject.toml +0 -66
  33. {goose_py-0.6.0 → goose_py-0.7.1}/README.md +0 -0
  34. /goose_py-0.6.0/goose/_result.py → /goose_py-0.7.1/goose/_internal/result.py +0 -0
  35. {goose_py-0.6.0/goose → goose_py-0.7.1/goose/_internal}/types/__init__.py +0 -0
  36. {goose_py-0.6.0/goose → goose_py-0.7.1/goose/_internal}/types/agent.py +0 -0
  37. {goose_py-0.6.0 → goose_py-0.7.1}/goose/errors.py +0 -0
  38. {goose_py-0.6.0 → goose_py-0.7.1}/goose/py.typed +0 -0
goose_py-0.7.1/.envrc ADDED
@@ -0,0 +1 @@
1
+ dotenv
@@ -0,0 +1,48 @@
1
name: CI/CD

on:
  push:
    branches:
      - main
    tags:
      - "v*"
  pull_request:
    branches:
      - main

jobs:
  publish:
    runs-on: ubuntu-latest

    # An explicit `permissions` block revokes every scope it does not list, so
    # grant read access for checkout alongside the OIDC token that trusted
    # publishing needs.
    permissions:
      contents: read
      id-token: write

    steps:
      - uses: actions/checkout@v4
        with:
          fetch-depth: 0

      - name: Setup Python
        uses: actions/setup-python@v5
        with:
          python-version-file: .python-version

      - name: Setup UV
        uses: astral-sh/setup-uv@v5
        with:
          version: "0.5.25"

      - name: Initialize environment
        run: uv sync --all-extras --dev

      - name: Run tests
        run: uv run pytest
        env:
          GEMINI_API_KEY: ${{ secrets.GEMINI_API_KEY }}

      - name: Build package
        run: uv build

      # Publish only for release tags. Previously every push to main and every
      # pull request attempted an upload, and `continue-on-error` hid the
      # resulting failures, including genuine failures on release tags.
      - name: Publish package
        if: startsWith(github.ref, 'refs/tags/v')
        run: uv publish
@@ -0,0 +1,5 @@
1
+ __pycache__
2
+ .env
3
+ poetry.lock
4
+ notebooks
5
+ dist
@@ -0,0 +1 @@
1
+ 3.12
@@ -0,0 +1,14 @@
1
+ from typing import Any
2
+
3
# Hand-written type stub for the subset of ``jsonpath_ng`` that goose type-checks
# against; pyright loads it from ``.stubs`` (see ``stubPath`` in pyproject.toml).


class Fields:
    """Stub of a jsonpath_ng node whose ``fields`` attribute is a tuple of strings."""

    fields: tuple[str, ...]


class DatumInContext:
    """Stub of one element of the list returned by ``Jsonpath.find``."""

    value: Any  # typed loosely as Any; presumably the matched value
    # NOTE(review): upstream jsonpath_ng exposes the match location as ``path``
    # rather than ``fields`` — confirm this attribute exists at runtime.
    fields: Fields
    context: "DatumInContext"  # quoted forward reference to this same class


class Jsonpath:
    """Stub of the object returned by ``parse``."""

    # NOTE(review): upstream names this parameter ``data``; ``find(value=...)``
    # type-checks against this stub but may fail at runtime — pass it positionally.
    def find(self, /, value: Any) -> list[DatumInContext]: ...


# ``jsonpath`` is positional-only in this stub.
def parse(jsonpath: str, /) -> Jsonpath: ...
@@ -0,0 +1,61 @@
1
+ from typing import Any, Literal, NotRequired, TypedDict
2
+
3
# Hand-written type stub for the subset of ``litellm`` that goose calls
# (``acompletion`` and the request/response shapes around it). pyright loads it
# from ``.stubs`` (see ``stubPath`` in pyproject.toml).

# Model identifiers accepted by this stub's ``acompletion``; only these three
# Vertex AI Gemini 1.5 variants are declared.
_LiteLLMGeminiModel = Literal[
    "vertex_ai/gemini-1.5-flash",
    "vertex_ai/gemini-1.5-pro",
    "vertex_ai/gemini-1.5-flash-8b",
]
_MessageRole = Literal["system", "user", "assistant"]


class _LiteLLMTextMessageContent(TypedDict):
    """A ``"text"`` part of a message's ``content`` list."""

    type: Literal["text"]
    text: str


class _LiteLLMMediaMessageContent(TypedDict):
    """An ``"image_url"`` part of a message's ``content`` list; ``image_url`` is a plain string here."""

    type: Literal["image_url"]
    image_url: str


class _LiteLLMCacheControl(TypedDict):
    """Cache-control marker for a message; only ``"ephemeral"`` is declared."""

    type: Literal["ephemeral"]


class _LiteLLMMessage(TypedDict):
    """One chat message: a role, a list of content parts, and optional cache control."""

    role: _MessageRole
    content: list[_LiteLLMTextMessageContent | _LiteLLMMediaMessageContent]
    cache_control: NotRequired[_LiteLLMCacheControl]


class _LiteLLMResponseFormat(TypedDict):
    """Value accepted by ``acompletion(response_format=...)`` for JSON-schema output."""

    type: Literal["json_object"]
    response_schema: dict[str, Any]  # must be a valid JSON schema
    enforce_validation: NotRequired[bool]


class _LiteLLMModelResponseChoiceMessage:
    """The assistant message carried by a response choice."""

    role: Literal["assistant"]
    content: str


class _LiteLLMModelResponseChoice:
    """One entry of ``ModelResponse.choices``."""

    # NOTE(review): narrowed to "stop"; completions can also end for other
    # reasons (e.g. hitting ``max_tokens``) — confirm nothing relies on this.
    finish_reason: Literal["stop"]
    index: int
    message: _LiteLLMModelResponseChoiceMessage


class _LiteLLMUsage:
    """Token counts attached to a ``ModelResponse``."""

    completion_tokens: int
    prompt_tokens: int
    total_tokens: int


class ModelResponse:
    """Return type of ``acompletion`` as declared by this stub."""

    id: str
    created: int
    model: _LiteLLMGeminiModel
    object: Literal["chat.completion"]
    system_fingerprint: str | None
    choices: list[_LiteLLMModelResponseChoice]
    usage: _LiteLLMUsage


# Keyword-only async chat completion. NOTE(review): ``temperature=1.0`` is a
# stub-declared default — confirm it matches litellm's runtime default.
async def acompletion(
    *,
    model: _LiteLLMGeminiModel,
    messages: list[_LiteLLMMessage],
    response_format: _LiteLLMResponseFormat | None = None,
    max_tokens: int | None = None,
    temperature: float = 1.0,
) -> ModelResponse: ...
@@ -0,0 +1,9 @@
1
# Developer shortcuts; every recipe delegates to uv-managed tools.
# These targets name actions, not files, so mark them phony: otherwise a stray
# file called e.g. `test` or `lint` would make `make` skip the recipe.
.PHONY: test lint publish

test:
	uv run pytest tests

# NOTE: `ruff format .` rewrites files in place, so this target formats as well as lints.
lint:
	uv run ruff check .
	uv run ruff format .
	uv run pyright .

publish:
	uv build
	uv publish
@@ -0,0 +1,14 @@
1
+ Metadata-Version: 2.4
2
+ Name: goose-py
3
+ Version: 0.7.1
4
+ Summary: A tool for AI workflows based on human-computer collaboration and structured output.
5
+ Author-email: Nash Taylor <nash@chelle.ai>, Joshua Cook <joshua@chelle.ai>, Michael Sankur <michael@chelle.ai>
6
+ Requires-Python: >=3.12
7
+ Requires-Dist: jsonpath-ng>=1.7.0
8
+ Requires-Dist: litellm>=1.56.5
9
+ Requires-Dist: pydantic>=2.8.2
10
+ Description-Content-Type: text/markdown
11
+
12
+ # Goose
13
+
14
+ Docs to come.
@@ -0,0 +1,4 @@
1
"""Top-level public API of goose.

Re-exports the core entry points from ``goose._internal``. ``__all__`` marks them
as intentional re-exports: the package ships ``py.typed``, and strict type
checkers do not treat a plain ``from x import y`` in such a package as a public
re-export. This matches the explicit ``__all__`` used by the other public modules.
"""

from goose._internal.agent import Agent
from goose._internal.flow import flow
from goose._internal.result import Result, TextResult
from goose._internal.task import task

# NOTE(review): the decorator ``flow`` shares its name with the ``goose.flow``
# submodule. Importing ``goose.flow`` rebinds the package attribute ``goose.flow``
# to that module, after which ``from goose import flow`` yields the module, not
# the decorator. Consider renaming one of them.

__all__ = ["Agent", "Result", "TextResult", "flow", "task"]
@@ -6,8 +6,13 @@ from typing import Any, ClassVar, Protocol, TypedDict
6
6
  from litellm import acompletion
7
7
  from pydantic import BaseModel, computed_field
8
8
 
9
- from goose._result import Result, TextResult
10
- from goose.types.agent import AssistantMessage, GeminiModel, SystemMessage, UserMessage
9
+ from goose._internal.result import Result, TextResult
10
+ from goose._internal.types.agent import (
11
+ AssistantMessage,
12
+ GeminiModel,
13
+ SystemMessage,
14
+ UserMessage,
15
+ )
11
16
 
12
17
 
13
18
  class AgentResponseDump(TypedDict):
@@ -60,20 +65,12 @@ class AgentResponse[R: BaseModel | str](BaseModel):
60
65
  @computed_field
61
66
  @property
62
67
  def input_cost(self) -> float:
63
- return (
64
- self.INPUT_CENTS_PER_MILLION_TOKENS[self.model]
65
- * self.input_tokens
66
- / 1_000_000
67
- )
68
+ return self.INPUT_CENTS_PER_MILLION_TOKENS[self.model] * self.input_tokens / 1_000_000
68
69
 
69
70
  @computed_field
70
71
  @property
71
72
  def output_cost(self) -> float:
72
- return (
73
- self.OUTPUT_CENTS_PER_MILLION_TOKENS[self.model]
74
- * self.output_tokens
75
- / 1_000_000
76
- )
73
+ return self.OUTPUT_CENTS_PER_MILLION_TOKENS[self.model] * self.output_tokens / 1_000_000
77
74
 
78
75
  @computed_field
79
76
  @property
@@ -95,15 +92,9 @@ class AgentResponse[R: BaseModel | str](BaseModel):
95
92
  for part in message["content"]:
96
93
  if part["type"] == "image_url":
97
94
  part["image_url"] = "__MEDIA__"
98
- minimized_input_messages = [
99
- json.dumps(message) for message in minimized_input_messages
100
- ]
101
-
102
- output_message = (
103
- self.response.model_dump_json()
104
- if isinstance(self.response, BaseModel)
105
- else self.response
106
- )
95
+ minimized_input_messages = [json.dumps(message) for message in minimized_input_messages]
96
+
97
+ output_message = self.response.model_dump_json() if isinstance(self.response, BaseModel) else self.response
107
98
 
108
99
  return {
109
100
  "run_id": self.run_id,
@@ -156,9 +147,7 @@ class Agent:
156
147
 
157
148
  if response_model is TextResult:
158
149
  response = await acompletion(model=model.value, messages=rendered_messages)
159
- parsed_response = response_model.model_validate(
160
- {"text": response.choices[0].message.content}
161
- )
150
+ parsed_response = response_model.model_validate({"text": response.choices[0].message.content})
162
151
  else:
163
152
  response = await acompletion(
164
153
  model=model.value,
@@ -169,9 +158,7 @@ class Agent:
169
158
  "enforce_validation": True,
170
159
  },
171
160
  )
172
- parsed_response = response_model.model_validate_json(
173
- response.choices[0].message.content
174
- )
161
+ parsed_response = response_model.model_validate_json(response.choices[0].message.content)
175
162
 
176
163
  end_time = datetime.now()
177
164
  agent_response = AgentResponse(
@@ -1,7 +1,12 @@
1
1
  from pydantic import BaseModel
2
2
 
3
- from goose._result import Result
4
- from goose.types.agent import AssistantMessage, LLMMessage, SystemMessage, UserMessage
3
+ from goose._internal.result import Result
4
+ from goose._internal.types.agent import (
5
+ AssistantMessage,
6
+ LLMMessage,
7
+ SystemMessage,
8
+ UserMessage,
9
+ )
5
10
 
6
11
 
7
12
  class Conversation[R: Result](BaseModel):
@@ -19,18 +24,10 @@ class Conversation[R: Result](BaseModel):
19
24
  messages.append(self.context.render())
20
25
 
21
26
  for message_index in range(len(self.user_messages)):
22
- messages.append(
23
- AssistantMessage(
24
- text=self.result_messages[message_index].model_dump_json()
25
- ).render()
26
- )
27
+ messages.append(AssistantMessage(text=self.result_messages[message_index].model_dump_json()).render())
27
28
  messages.append(self.user_messages[message_index].render())
28
29
 
29
30
  if len(self.result_messages) > len(self.user_messages):
30
- messages.append(
31
- AssistantMessage(
32
- text=self.result_messages[-1].model_dump_json()
33
- ).render()
34
- )
31
+ messages.append(AssistantMessage(text=self.result_messages[-1].model_dump_json()).render())
35
32
 
36
33
  return messages
@@ -1,21 +1,20 @@
1
+ from collections.abc import AsyncIterator, Awaitable, Callable
1
2
  from contextlib import asynccontextmanager
2
3
  from types import CodeType
3
- from typing import AsyncIterator, Awaitable, Callable, Protocol, overload
4
+ from typing import Protocol, overload
4
5
 
5
- from goose._agent import Agent, IAgentLogger
6
- from goose._conversation import Conversation
7
- from goose._result import Result
8
- from goose._state import FlowRun, get_current_flow_run, set_current_flow_run
9
- from goose._store import IFlowRunStore, InMemoryFlowRunStore
6
+ from goose._internal.agent import Agent, IAgentLogger
7
+ from goose._internal.conversation import Conversation
8
+ from goose._internal.result import Result
9
+ from goose._internal.state import FlowRun, get_current_flow_run, set_current_flow_run
10
+ from goose._internal.store import IFlowRunStore, InMemoryFlowRunStore
10
11
  from goose.errors import Honk
11
12
 
12
13
 
13
14
  class IAdapter[ResultT: Result](Protocol):
14
15
  __code__: CodeType
15
16
 
16
- async def __call__(
17
- self, *, conversation: Conversation[ResultT], agent: Agent
18
- ) -> ResultT: ...
17
+ async def __call__(self, *, conversation: Conversation[ResultT], agent: Agent) -> ResultT: ...
19
18
 
20
19
 
21
20
  class Flow[**P]:
@@ -4,18 +4,18 @@ from typing import TYPE_CHECKING, Any, Self
4
4
 
5
5
  from pydantic import BaseModel
6
6
 
7
- from goose._agent import (
7
+ from goose._internal.agent import (
8
8
  Agent,
9
9
  IAgentLogger,
10
10
  SystemMessage,
11
11
  UserMessage,
12
12
  )
13
- from goose._conversation import Conversation
14
- from goose._result import Result
13
+ from goose._internal.conversation import Conversation
14
+ from goose._internal.result import Result
15
15
  from goose.errors import Honk
16
16
 
17
17
  if TYPE_CHECKING:
18
- from goose._task import Task
18
+ from goose._internal.task import Task
19
19
 
20
20
 
21
21
  @dataclass
@@ -97,23 +97,17 @@ class FlowRun:
97
97
  matching_nodes: list[NodeState[R]] = []
98
98
  for key, node_state in self._node_states.items():
99
99
  if key[0] == task.name:
100
- matching_nodes.append(
101
- NodeState[task.result_type].model_validate_json(node_state)
102
- )
100
+ matching_nodes.append(NodeState[task.result_type].model_validate_json(node_state))
103
101
  return sorted(matching_nodes, key=lambda node: node.index)
104
102
 
105
103
  def get[R: Result](self, *, task: "Task[Any, R]", index: int = 0) -> NodeState[R]:
106
- if (
107
- existing_node_state := self._node_states.get((task.name, index))
108
- ) is not None:
104
+ if (existing_node_state := self._node_states.get((task.name, index))) is not None:
109
105
  return NodeState[task.result_type].model_validate_json(existing_node_state)
110
106
  else:
111
107
  return NodeState[task.result_type](
112
108
  task_name=task.name,
113
109
  index=index,
114
- conversation=Conversation[task.result_type](
115
- user_messages=[], result_messages=[]
116
- ),
110
+ conversation=Conversation[task.result_type](user_messages=[], result_messages=[]),
117
111
  last_hash=0,
118
112
  )
119
113
 
@@ -143,9 +137,7 @@ class FlowRun:
143
137
  self._last_requested_indices = {}
144
138
  self._flow_name = flow_name
145
139
  self._id = run_id
146
- self._agent = Agent(
147
- flow_name=self.flow_name, run_id=self.id, logger=agent_logger
148
- )
140
+ self._agent = Agent(flow_name=self.flow_name, run_id=self.id, logger=agent_logger)
149
141
 
150
142
  def end(self) -> None:
151
143
  self._last_requested_indices = {}
@@ -177,9 +169,7 @@ class FlowRun:
177
169
  return flow_run
178
170
 
179
171
 
180
- _current_flow_run: ContextVar[FlowRun | None] = ContextVar(
181
- "current_flow_run", default=None
182
- )
172
+ _current_flow_run: ContextVar[FlowRun | None] = ContextVar("current_flow_run", default=None)
183
173
 
184
174
 
185
175
  def get_current_flow_run() -> FlowRun | None:
@@ -2,8 +2,8 @@ from __future__ import annotations
2
2
 
3
3
  from typing import Protocol
4
4
 
5
- from goose._flow import FlowRun
6
- from goose._state import FlowRunState
5
+ from goose._internal.flow import FlowRun
6
+ from goose._internal.state import FlowRunState
7
7
 
8
8
 
9
9
  class IFlowRunStore(Protocol):
@@ -1,11 +1,12 @@
1
- from typing import Awaitable, Callable, overload
2
-
3
- from goose._agent import Agent, GeminiModel, SystemMessage, UserMessage
4
- from goose._conversation import Conversation
5
- from goose._result import Result, TextResult
6
- from goose._state import FlowRun, NodeState, get_current_flow_run
1
+ from collections.abc import Awaitable, Callable
2
+ from typing import overload
3
+
4
+ from goose._internal.agent import Agent, GeminiModel, SystemMessage, UserMessage
5
+ from goose._internal.conversation import Conversation
6
+ from goose._internal.result import Result, TextResult
7
+ from goose._internal.state import FlowRun, NodeState, get_current_flow_run
8
+ from goose._internal.types.agent import AssistantMessage
7
9
  from goose.errors import Honk
8
- from goose.types.agent import AssistantMessage
9
10
 
10
11
 
11
12
  class Task[**P, R: Result]:
@@ -33,9 +34,7 @@ class Task[**P, R: Result]:
33
34
  def name(self) -> str:
34
35
  return self._generator.__name__
35
36
 
36
- async def generate(
37
- self, state: NodeState[R], *args: P.args, **kwargs: P.kwargs
38
- ) -> R:
37
+ async def generate(self, state: NodeState[R], *args: P.args, **kwargs: P.kwargs) -> R:
39
38
  state_hash = self.__hash_task_call(*args, **kwargs)
40
39
  if state_hash != state.last_hash:
41
40
  result = await self._generator(*args, **kwargs)
@@ -58,9 +57,7 @@ class Task[**P, R: Result]:
58
57
  node_state.set_context(context=context)
59
58
  node_state.add_user_message(message=user_message)
60
59
 
61
- result = await self.__adapt(
62
- conversation=node_state.conversation, agent=flow_run.agent
63
- )
60
+ result = await self.__adapt(conversation=node_state.conversation, agent=flow_run.agent)
64
61
  node_state.add_result(result=result)
65
62
  flow_run.add_node_state(node_state)
66
63
 
@@ -97,11 +94,7 @@ class Task[**P, R: Result]:
97
94
 
98
95
  def __hash_task_call(self, *args: P.args, **kwargs: P.kwargs) -> int:
99
96
  try:
100
- to_hash = str(
101
- tuple(args)
102
- + tuple(kwargs.values())
103
- + (self._generator.__code__, self._adapter_model)
104
- )
97
+ to_hash = str(tuple(args) + tuple(kwargs.values()) + (self._generator.__code__, self._adapter_model))
105
98
  return hash(to_hash)
106
99
  except TypeError:
107
100
  raise Honk(f"Unhashable argument to task {self.name}: {args} {kwargs}")
@@ -0,0 +1,28 @@
1
"""Public agent API: re-exports agent response/logger types and message types.

Everything here is defined in ``goose._internal``; ``__all__`` declares the
supported names.
"""

from goose._internal.agent import AgentResponse, IAgentLogger
from goose._internal.types.agent import (
    AssistantMessage,
    GeminiModel,
    LLMMediaMessagePart,
    LLMMessage,
    LLMTextMessagePart,
    MediaMessagePart,
    SystemMessage,
    TextMessagePart,
    UserMediaContentType,
    UserMessage,
)

__all__ = [
    "AgentResponse",
    "IAgentLogger",
    "AssistantMessage",
    "GeminiModel",
    "LLMMediaMessagePart",
    "LLMMessage",
    "LLMTextMessagePart",
    "MediaMessagePart",
    "SystemMessage",
    "TextMessagePart",
    "UserMediaContentType",
    "UserMessage",
]
@@ -0,0 +1,3 @@
1
"""Public re-export of the ``Flow`` class from ``goose._internal.flow``."""

from goose._internal.flow import Flow

__all__ = ["Flow"]
@@ -0,0 +1,4 @@
1
"""Public re-exports for flow-run state and the flow-run store protocol."""

from goose._internal.state import FlowRun, FlowRunState
from goose._internal.store import IFlowRunStore

__all__ = ["FlowRun", "FlowRunState", "IFlowRunStore"]
@@ -0,0 +1,81 @@
1
[project]
name = "goose-py"
version = "0.7.1"
description = "A tool for AI workflows based on human-computer collaboration and structured output."
readme = "README.md"
authors = [
    { name = "Nash Taylor", email = "nash@chelle.ai" },
    { name = "Joshua Cook", email = "joshua@chelle.ai" },
    { name = "Michael Sankur", email = "michael@chelle.ai" },
]
requires-python = ">=3.12"
dependencies = [
    "jsonpath-ng>=1.7.0",
    "litellm>=1.56.5",
    "pydantic>=2.8.2",
]

[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"

[tool.hatch.build.targets.wheel]
packages = ["goose"]

[dependency-groups]
dev = [
    "pyright>=1.1.393",
    "pytest>=8.3.4",
    "pytest-asyncio>=0.25.3",
    "pytest-mock>=3.14.0",
    "ruff>=0.9.4",
]


[tool.ruff]
exclude = [
    "__init__.py",
    ".venv",
    "**/.venv",
    "notebooks",
    ".stubs",
]
force-exclude = true
line-length = 120

[tool.ruff.lint]
# TID252 (relative-imports) must be selected for the
# `flake8-tidy-imports.ban-relative-imports` setting below to be enforced;
# without it that setting has no effect.
select = [ "E", "F", "I", "UP", "TID252" ]
ignore = [ "E501" ]


[tool.ruff.lint.isort]
known-first-party = [
    "goose",
]

[tool.ruff.lint.flake8-tidy-imports]
ban-relative-imports = "all"


[tool.pyright]
pythonVersion = "3.12"
typeCheckingMode = "strict"
reportMissingModuleSource = false
useLibraryCodeForTypes = false
reportUnknownMemberType = false
reportUnknownVariableType = false
stubPath = ".stubs"
venvPath = "."
venv = ".venv"
exclude = ["goose/__init__.py", ".venv", "notebooks"]

[tool.pytest.ini_options]
filterwarnings = [
    "ignore::DeprecationWarning",
    "ignore::SyntaxWarning",
    "ignore::UserWarning",
]
addopts = "-v"
testpaths = ["tests"]
pythonpath = ["."]
python_files = "test_*.py"
File without changes
@@ -0,0 +1,10 @@
1
+ import asyncio
2
+
3
+ import pytest
4
+
5
+
6
@pytest.fixture(scope="session")
def event_loop():
    """Provide one event loop shared by every async test in the session.

    The loop is created explicitly instead of fetched with
    ``asyncio.get_event_loop()``. Outside a running loop that call is deprecated
    (Python 3.12 warns when there is no current loop). It can also return a loop
    that other code installed as current, which this fixture would then close.
    Owning the loop it creates keeps setup and the ``close()`` in teardown symmetric.
    """
    loop = asyncio.new_event_loop()
    yield loop
    loop.close()
@@ -0,0 +1,66 @@
1
+ from unittest.mock import Mock
2
+
3
+ import pytest
4
+ from pytest_mock import MockerFixture
5
+
6
+ from goose import TextResult, flow, task
7
+ from goose._internal.agent import Agent, AgentResponse, IAgentLogger
8
+ from goose.agent import GeminiModel, TextMessagePart, UserMessage
9
+
10
+
11
class MockLiteLLMResponse:
    """Lightweight stand-in for a litellm completion response.

    Only two things are populated: a single-element ``choices`` list whose
    ``message.content`` is the canned text, and a ``usage`` object carrying the
    given prompt and completion token counts.
    """

    def __init__(self, *, response: str, prompt_tokens: int, completion_tokens: int) -> None:
        canned_message = Mock(content=response)
        only_choice = Mock(message=canned_message)
        self.choices = [only_choice]

        token_usage = Mock(prompt_tokens=prompt_tokens, completion_tokens=completion_tokens)
        self.usage = token_usage
15
+
16
+
17
@pytest.fixture
def mock_litellm(mocker: MockerFixture) -> Mock:
    """Patch ``acompletion`` as referenced by ``goose._internal.agent``.

    The patch's ``return_value`` is a canned response with text ``"Hello"`` and
    10 prompt / 10 completion tokens. The patch object itself is returned.
    """
    canned_reply = MockLiteLLMResponse(response="Hello", prompt_tokens=10, completion_tokens=10)
    patch_target = "goose._internal.agent.acompletion"
    return mocker.patch(patch_target, return_value=canned_reply)
23
+
24
+
25
@task
async def use_agent(*, agent: Agent) -> TextResult:
    """Send a single "Hello" user message to the agent on ``GeminiModel.FLASH_8B``."""
    hello = UserMessage(parts=[TextMessagePart(text="Hello")])
    conversation = [hello]
    return await agent(messages=conversation, model=GeminiModel.FLASH_8B, task_name="greet")
32
+
33
+
34
@flow
async def agent_flow(*, agent: Agent) -> None:
    """Flow (default agent logger) that runs the ``use_agent`` task once."""
    await use_agent(agent=agent)
37
+
38
+
39
class CustomLogger(IAgentLogger):
    """Agent logger that records every response it receives.

    ``logged_responses`` lives on the class, so every instance (and every test
    in the process) appends to the same list; tests read it as
    ``CustomLogger.logged_responses``.
    """

    logged_responses: list[AgentResponse[TextResult]] = []

    async def __call__(self, *, response: AgentResponse[TextResult]) -> None:
        # Go through the class explicitly to make the shared storage obvious.
        type(self).logged_responses.append(response)
44
+
45
+
46
@flow(agent_logger=CustomLogger())
async def agent_flow_with_custom_logger(*, agent: Agent) -> None:
    """Same as ``agent_flow`` but with a ``CustomLogger`` attached to the flow."""
    await use_agent(agent=agent)
49
+
50
+
51
@pytest.mark.asyncio
@pytest.mark.usefixtures("mock_litellm")
async def test_agent() -> None:
    """The ``use_agent`` node of a finished run holds the mocked "Hello" text."""
    async with agent_flow.start_run(run_id="1") as run:
        await agent_flow.generate(agent=run.agent)

    greeting = run.get(task=use_agent).result
    assert greeting.text == "Hello"
58
+
59
+
60
@pytest.mark.asyncio
@pytest.mark.usefixtures("mock_litellm")
async def test_agent_custom_logger() -> None:
    """One run of the flow makes ``CustomLogger`` record exactly one response.

    ``CustomLogger.logged_responses`` is a class-level list shared by all
    instances and all tests in the process. The test therefore compares against
    the length seen before the run instead of assuming the list starts empty,
    which keeps it independent of test order and repetition.
    """
    logged_before = len(CustomLogger.logged_responses)

    async with agent_flow_with_custom_logger.start_run(run_id="1") as run:
        await agent_flow_with_custom_logger.generate(agent=run.agent)

    assert len(CustomLogger.logged_responses) == logged_before + 1
@@ -0,0 +1,22 @@
1
+ import pytest
2
+ from pydantic import BaseModel
3
+
4
+ from goose import Agent, flow
5
+
6
+
7
class MyMessage(BaseModel):
    """Pydantic model passed to ``my_flow`` as its ``message`` argument."""

    text: str
9
+
10
+
11
@flow
async def my_flow(*, message: MyMessage, agent: Agent) -> None:
    """Flow with an intentionally empty body that takes a Pydantic-model argument."""
14
+
15
+
16
@pytest.mark.asyncio
async def test_my_flow() -> None:
    """Run ``my_flow`` twice under the same run id with a Pydantic-model argument.

    Passes as long as neither run raises.
    """
    for _attempt in range(2):
        async with my_flow.start_run(run_id="1") as run:
            await my_flow.generate(message=MyMessage(text="Hello"), agent=run.agent)