goose-py 0.5.1__py3-none-any.whl → 0.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
goose/__init__.py CHANGED
@@ -0,0 +1,4 @@
1
+ from goose._internal.agent import Agent
2
+ from goose._internal.flow import flow
3
+ from goose._internal.result import Result, TextResult
4
+ from goose._internal.task import task
@@ -0,0 +1,201 @@
1
+ import json
2
+ import logging
3
+ from datetime import datetime
4
+ from typing import Any, ClassVar, Protocol, TypedDict
5
+
6
+ from litellm import acompletion
7
+ from pydantic import BaseModel, computed_field
8
+
9
+ from goose._internal.result import Result, TextResult
10
+ from goose._internal.types.agent import (
11
+ AssistantMessage,
12
+ GeminiModel,
13
+ SystemMessage,
14
+ UserMessage,
15
+ )
16
+
17
+
18
# Flat, JSON-friendly summary of one agent call, as returned by
# ``AgentResponse.minimized_dump``. Declared with the functional TypedDict
# syntax; key order follows the dict literal below.
AgentResponseDump = TypedDict(
    "AgentResponseDump",
    {
        "run_id": str,
        "flow_name": str,
        "task_name": str,
        "model": str,
        # JSON-encoded rendered system message, or "" when there was none.
        "system_message": str,
        # One JSON-encoded rendered message per input message.
        "input_messages": list[str],
        "output_message": str,
        "input_cost": float,
        "output_cost": float,
        "total_cost": float,
        "input_tokens": int,
        "output_tokens": int,
        "start_time": datetime,
        "end_time": datetime,
        "duration_ms": int,
    },
)
+
35
+
36
+ class AgentResponse[R: BaseModel | str](BaseModel):
37
+ INPUT_CENTS_PER_MILLION_TOKENS: ClassVar[dict[GeminiModel, float]] = {
38
+ GeminiModel.FLASH_8B: 30,
39
+ GeminiModel.FLASH: 15,
40
+ GeminiModel.PRO: 500,
41
+ }
42
+ OUTPUT_CENTS_PER_MILLION_TOKENS: ClassVar[dict[GeminiModel, float]] = {
43
+ GeminiModel.FLASH_8B: 30,
44
+ GeminiModel.FLASH: 15,
45
+ GeminiModel.PRO: 500,
46
+ }
47
+
48
+ response: R
49
+ run_id: str
50
+ flow_name: str
51
+ task_name: str
52
+ model: GeminiModel
53
+ system: SystemMessage | None = None
54
+ input_messages: list[UserMessage | AssistantMessage]
55
+ input_tokens: int
56
+ output_tokens: int
57
+ start_time: datetime
58
+ end_time: datetime
59
+
60
+ @computed_field
61
+ @property
62
+ def duration_ms(self) -> int:
63
+ return int((self.end_time - self.start_time).total_seconds() * 1000)
64
+
65
+ @computed_field
66
+ @property
67
+ def input_cost(self) -> float:
68
+ return (
69
+ self.INPUT_CENTS_PER_MILLION_TOKENS[self.model]
70
+ * self.input_tokens
71
+ / 1_000_000
72
+ )
73
+
74
+ @computed_field
75
+ @property
76
+ def output_cost(self) -> float:
77
+ return (
78
+ self.OUTPUT_CENTS_PER_MILLION_TOKENS[self.model]
79
+ * self.output_tokens
80
+ / 1_000_000
81
+ )
82
+
83
+ @computed_field
84
+ @property
85
+ def total_cost(self) -> float:
86
+ return self.input_cost + self.output_cost
87
+
88
+ def minimized_dump(self) -> AgentResponseDump:
89
+ if self.system is None:
90
+ minimized_system_message = ""
91
+ else:
92
+ minimized_system_message = self.system.render()
93
+ for part in minimized_system_message["content"]:
94
+ if part["type"] == "image_url":
95
+ part["image_url"] = "__MEDIA__"
96
+ minimized_system_message = json.dumps(minimized_system_message)
97
+
98
+ minimized_input_messages = [message.render() for message in self.input_messages]
99
+ for message in minimized_input_messages:
100
+ for part in message["content"]:
101
+ if part["type"] == "image_url":
102
+ part["image_url"] = "__MEDIA__"
103
+ minimized_input_messages = [
104
+ json.dumps(message) for message in minimized_input_messages
105
+ ]
106
+
107
+ output_message = (
108
+ self.response.model_dump_json()
109
+ if isinstance(self.response, BaseModel)
110
+ else self.response
111
+ )
112
+
113
+ return {
114
+ "run_id": self.run_id,
115
+ "flow_name": self.flow_name,
116
+ "task_name": self.task_name,
117
+ "model": self.model.value,
118
+ "system_message": minimized_system_message,
119
+ "input_messages": minimized_input_messages,
120
+ "output_message": output_message,
121
+ "input_tokens": self.input_tokens,
122
+ "output_tokens": self.output_tokens,
123
+ "input_cost": self.input_cost,
124
+ "output_cost": self.output_cost,
125
+ "total_cost": self.total_cost,
126
+ "start_time": self.start_time,
127
+ "end_time": self.end_time,
128
+ "duration_ms": self.duration_ms,
129
+ }
130
+
131
+
132
class IAgentLogger(Protocol):
    """Async sink for the ``AgentResponse`` records produced by an ``Agent``.

    Any object whose ``__call__`` is an async function taking a keyword-only
    ``response`` argument satisfies this protocol structurally.
    """

    async def __call__(self, *, response: AgentResponse[Any]) -> None:
        """Consume one completed agent response."""
134
+
135
+
136
+ class Agent:
137
+ def __init__(
138
+ self,
139
+ *,
140
+ flow_name: str,
141
+ run_id: str,
142
+ logger: IAgentLogger | None = None,
143
+ ) -> None:
144
+ self.flow_name = flow_name
145
+ self.run_id = run_id
146
+ self.logger = logger
147
+
148
+ async def __call__[R: Result](
149
+ self,
150
+ *,
151
+ messages: list[UserMessage | AssistantMessage],
152
+ model: GeminiModel,
153
+ task_name: str,
154
+ response_model: type[R] = TextResult,
155
+ system: SystemMessage | None = None,
156
+ ) -> R:
157
+ start_time = datetime.now()
158
+ rendered_messages = [message.render() for message in messages]
159
+ if system is not None:
160
+ rendered_messages.insert(0, system.render())
161
+
162
+ if response_model is TextResult:
163
+ response = await acompletion(model=model.value, messages=rendered_messages)
164
+ parsed_response = response_model.model_validate(
165
+ {"text": response.choices[0].message.content}
166
+ )
167
+ else:
168
+ response = await acompletion(
169
+ model=model.value,
170
+ messages=rendered_messages,
171
+ response_format={
172
+ "type": "json_object",
173
+ "response_schema": response_model.model_json_schema(),
174
+ "enforce_validation": True,
175
+ },
176
+ )
177
+ parsed_response = response_model.model_validate_json(
178
+ response.choices[0].message.content
179
+ )
180
+
181
+ end_time = datetime.now()
182
+ agent_response = AgentResponse(
183
+ response=parsed_response,
184
+ run_id=self.run_id,
185
+ flow_name=self.flow_name,
186
+ task_name=task_name,
187
+ model=model,
188
+ system=system,
189
+ input_messages=messages,
190
+ input_tokens=response.usage.prompt_tokens,
191
+ output_tokens=response.usage.completion_tokens,
192
+ start_time=start_time,
193
+ end_time=end_time,
194
+ )
195
+
196
+ if self.logger is not None:
197
+ await self.logger(response=agent_response)
198
+ else:
199
+ logging.info(agent_response.model_dump())
200
+
201
+ return parsed_response
@@ -0,0 +1,41 @@
1
+ from pydantic import BaseModel
2
+
3
+ from goose._internal.result import Result
4
+ from goose._internal.types.agent import (
5
+ AssistantMessage,
6
+ LLMMessage,
7
+ SystemMessage,
8
+ UserMessage,
9
+ )
10
+
11
+
12
+ class Conversation[R: Result](BaseModel):
13
+ user_messages: list[UserMessage]
14
+ result_messages: list[R]
15
+ context: SystemMessage | None = None
16
+
17
+ @property
18
+ def awaiting_response(self) -> bool:
19
+ return len(self.user_messages) == len(self.result_messages)
20
+
21
+ def render(self) -> list[LLMMessage]:
22
+ messages: list[LLMMessage] = []
23
+ if self.context is not None:
24
+ messages.append(self.context.render())
25
+
26
+ for message_index in range(len(self.user_messages)):
27
+ messages.append(
28
+ AssistantMessage(
29
+ text=self.result_messages[message_index].model_dump_json()
30
+ ).render()
31
+ )
32
+ messages.append(self.user_messages[message_index].render())
33
+
34
+ if len(self.result_messages) > len(self.user_messages):
35
+ messages.append(
36
+ AssistantMessage(
37
+ text=self.result_messages[-1].model_dump_json()
38
+ ).render()
39
+ )
40
+
41
+ return messages
@@ -0,0 +1,106 @@
1
+ from contextlib import asynccontextmanager
2
+ from types import CodeType
3
+ from typing import AsyncIterator, Awaitable, Callable, Protocol, overload
4
+
5
+ from goose._internal.agent import Agent, IAgentLogger
6
+ from goose._internal.conversation import Conversation
7
+ from goose._internal.result import Result
8
+ from goose._internal.state import FlowRun, get_current_flow_run, set_current_flow_run
9
+ from goose._internal.store import IFlowRunStore, InMemoryFlowRunStore
10
+ from goose.errors import Honk
11
+
12
+
13
+ class IAdapter[ResultT: Result](Protocol):
14
+ __code__: CodeType
15
+
16
+ async def __call__(
17
+ self, *, conversation: Conversation[ResultT], agent: Agent
18
+ ) -> ResultT: ...
19
+
20
+
21
+ class Flow[**P]:
22
+ def __init__(
23
+ self,
24
+ fn: Callable[P, Awaitable[None]],
25
+ /,
26
+ *,
27
+ name: str | None = None,
28
+ store: IFlowRunStore | None = None,
29
+ agent_logger: IAgentLogger | None = None,
30
+ ) -> None:
31
+ self._fn = fn
32
+ self._name = name
33
+ self._agent_logger = agent_logger
34
+ self._store = store or InMemoryFlowRunStore(flow_name=self.name)
35
+
36
+ @property
37
+ def name(self) -> str:
38
+ return self._name or self._fn.__name__
39
+
40
+ @property
41
+ def current_run(self) -> FlowRun:
42
+ run = get_current_flow_run()
43
+ if run is None:
44
+ raise Honk("No current flow run")
45
+ return run
46
+
47
+ @asynccontextmanager
48
+ async def start_run(self, *, run_id: str) -> AsyncIterator[FlowRun]:
49
+ existing_run = await self._store.get(run_id=run_id)
50
+ if existing_run is None:
51
+ run = FlowRun()
52
+ else:
53
+ run = existing_run
54
+
55
+ old_run = get_current_flow_run()
56
+ set_current_flow_run(run)
57
+
58
+ run.start(flow_name=self.name, run_id=run_id, agent_logger=self._agent_logger)
59
+ yield run
60
+ await self._store.save(run=run)
61
+ run.end()
62
+
63
+ set_current_flow_run(old_run)
64
+
65
+ async def generate(self, *args: P.args, **kwargs: P.kwargs) -> None:
66
+ flow_run = get_current_flow_run()
67
+ if flow_run is None:
68
+ raise Honk("No current flow run")
69
+
70
+ flow_run.set_flow_inputs(*args, **kwargs)
71
+ await self._fn(*args, **kwargs)
72
+
73
+ async def regenerate(self) -> None:
74
+ flow_run = get_current_flow_run()
75
+ if flow_run is None:
76
+ raise Honk("No current flow run")
77
+
78
+ flow_args, flow_kwargs = flow_run.flow_inputs
79
+ await self._fn(*flow_args, **flow_kwargs)
80
+
81
+
82
+ @overload
83
+ def flow[**P](fn: Callable[P, Awaitable[None]], /) -> Flow[P]: ...
84
+ @overload
85
+ def flow[**P](
86
+ *,
87
+ name: str | None = None,
88
+ store: IFlowRunStore | None = None,
89
+ agent_logger: IAgentLogger | None = None,
90
+ ) -> Callable[[Callable[P, Awaitable[None]]], Flow[P]]: ...
91
+ def flow[**P](
92
+ fn: Callable[P, Awaitable[None]] | None = None,
93
+ /,
94
+ *,
95
+ name: str | None = None,
96
+ store: IFlowRunStore | None = None,
97
+ agent_logger: IAgentLogger | None = None,
98
+ ) -> Flow[P] | Callable[[Callable[P, Awaitable[None]]], Flow[P]]:
99
+ if fn is None:
100
+
101
+ def decorator(fn: Callable[P, Awaitable[None]]) -> Flow[P]:
102
+ return Flow(fn, name=name, store=store, agent_logger=agent_logger)
103
+
104
+ return decorator
105
+
106
+ return Flow(fn, name=name, store=store, agent_logger=agent_logger)
@@ -0,0 +1,190 @@
1
+ from contextvars import ContextVar
2
+ from dataclasses import dataclass
3
+ from typing import TYPE_CHECKING, Any, Self
4
+
5
+ from pydantic import BaseModel
6
+
7
+ from goose._internal.agent import (
8
+ Agent,
9
+ IAgentLogger,
10
+ SystemMessage,
11
+ UserMessage,
12
+ )
13
+ from goose._internal.conversation import Conversation
14
+ from goose._internal.result import Result
15
+ from goose.errors import Honk
16
+
17
+ if TYPE_CHECKING:
18
+ from goose._internal.task import Task
19
+
20
+
21
@dataclass(init=True, repr=True, eq=True)
class FlowRunState:
    """Storable snapshot of a ``FlowRun``, as produced by ``FlowRun.dump`` and read by ``FlowRun.load``."""

    # Serialized NodeState JSON keyed by (task name, invocation index).
    node_states: dict[tuple[str, int], str]
    # Positional and keyword arguments recorded via FlowRun.set_flow_inputs.
    flow_args: tuple[Any, ...]
    flow_kwargs: dict[str, Any]
26
+
27
+
28
+ class NodeState[ResultT: Result](BaseModel):
29
+ task_name: str
30
+ index: int
31
+ conversation: Conversation[ResultT]
32
+ last_hash: int
33
+
34
+ @property
35
+ def result(self) -> ResultT:
36
+ if len(self.conversation.result_messages) == 0:
37
+ raise Honk("Node awaiting response, has no result")
38
+
39
+ return self.conversation.result_messages[-1]
40
+
41
+ def set_context(self, *, context: SystemMessage) -> Self:
42
+ self.conversation.context = context
43
+ return self
44
+
45
+ def add_result(
46
+ self,
47
+ *,
48
+ result: ResultT,
49
+ new_hash: int | None = None,
50
+ overwrite: bool = False,
51
+ ) -> Self:
52
+ if overwrite and len(self.conversation.result_messages) > 0:
53
+ self.conversation.result_messages[-1] = result
54
+ else:
55
+ self.conversation.result_messages.append(result)
56
+ if new_hash is not None:
57
+ self.last_hash = new_hash
58
+ return self
59
+
60
+ def add_user_message(self, *, message: UserMessage) -> Self:
61
+ self.conversation.user_messages.append(message)
62
+ return self
63
+
64
+
65
+ class FlowRun:
66
+ def __init__(self) -> None:
67
+ self._node_states: dict[tuple[str, int], str] = {}
68
+ self._last_requested_indices: dict[str, int] = {}
69
+ self._flow_name = ""
70
+ self._id = ""
71
+ self._agent: Agent | None = None
72
+ self._flow_args: tuple[Any, ...] | None = None
73
+ self._flow_kwargs: dict[str, Any] | None = None
74
+
75
+ @property
76
+ def flow_name(self) -> str:
77
+ return self._flow_name
78
+
79
+ @property
80
+ def id(self) -> str:
81
+ return self._id
82
+
83
+ @property
84
+ def agent(self) -> Agent:
85
+ if self._agent is None:
86
+ raise Honk("Agent is only accessible once a run is started")
87
+ return self._agent
88
+
89
+ @property
90
+ def flow_inputs(self) -> tuple[tuple[Any, ...], dict[str, Any]]:
91
+ if self._flow_args is None or self._flow_kwargs is None:
92
+ raise Honk("This Flow run has not been executed before")
93
+
94
+ return self._flow_args, self._flow_kwargs
95
+
96
+ def get_all[R: Result](self, *, task: "Task[Any, R]") -> list[NodeState[R]]:
97
+ matching_nodes: list[NodeState[R]] = []
98
+ for key, node_state in self._node_states.items():
99
+ if key[0] == task.name:
100
+ matching_nodes.append(
101
+ NodeState[task.result_type].model_validate_json(node_state)
102
+ )
103
+ return sorted(matching_nodes, key=lambda node: node.index)
104
+
105
+ def get[R: Result](self, *, task: "Task[Any, R]", index: int = 0) -> NodeState[R]:
106
+ if (
107
+ existing_node_state := self._node_states.get((task.name, index))
108
+ ) is not None:
109
+ return NodeState[task.result_type].model_validate_json(existing_node_state)
110
+ else:
111
+ return NodeState[task.result_type](
112
+ task_name=task.name,
113
+ index=index,
114
+ conversation=Conversation[task.result_type](
115
+ user_messages=[], result_messages=[]
116
+ ),
117
+ last_hash=0,
118
+ )
119
+
120
+ def set_flow_inputs(self, *args: Any, **kwargs: Any) -> None:
121
+ self._flow_args = args
122
+ self._flow_kwargs = kwargs
123
+
124
+ def add_node_state(self, node_state: NodeState[Any], /) -> None:
125
+ key = (node_state.task_name, node_state.index)
126
+ self._node_states[key] = node_state.model_dump_json()
127
+
128
+ def get_next[R: Result](self, *, task: "Task[Any, R]") -> NodeState[R]:
129
+ if task.name not in self._last_requested_indices:
130
+ self._last_requested_indices[task.name] = 0
131
+ else:
132
+ self._last_requested_indices[task.name] += 1
133
+
134
+ return self.get(task=task, index=self._last_requested_indices[task.name])
135
+
136
+ def start(
137
+ self,
138
+ *,
139
+ flow_name: str,
140
+ run_id: str,
141
+ agent_logger: IAgentLogger | None = None,
142
+ ) -> None:
143
+ self._last_requested_indices = {}
144
+ self._flow_name = flow_name
145
+ self._id = run_id
146
+ self._agent = Agent(
147
+ flow_name=self.flow_name, run_id=self.id, logger=agent_logger
148
+ )
149
+
150
+ def end(self) -> None:
151
+ self._last_requested_indices = {}
152
+ self._flow_name = ""
153
+ self._id = ""
154
+ self._agent = None
155
+
156
+ def clear_node(self, *, task: "Task[Any, Result]", index: int) -> None:
157
+ key = (task.name, index)
158
+ if key in self._node_states:
159
+ del self._node_states[key]
160
+
161
+ def dump(self) -> FlowRunState:
162
+ flow_args, flow_kwargs = self.flow_inputs
163
+
164
+ return FlowRunState(
165
+ node_states=self._node_states,
166
+ flow_args=flow_args,
167
+ flow_kwargs=flow_kwargs,
168
+ )
169
+
170
+ @classmethod
171
+ def load(cls, flow_run_state: FlowRunState, /) -> Self:
172
+ flow_run = cls()
173
+ flow_run._node_states = flow_run_state.node_states
174
+ flow_run._flow_args = flow_run_state.flow_args
175
+ flow_run._flow_kwargs = flow_run_state.flow_kwargs
176
+
177
+ return flow_run
178
+
179
+
180
# Context-local slot for the FlowRun currently executing; defaults to None.
_current_flow_run: ContextVar[FlowRun | None] = ContextVar(
    "current_flow_run", default=None
)


def get_current_flow_run() -> FlowRun | None:
    """Return the flow run set for the current context, or None if none is set."""
    active_run = _current_flow_run.get()
    return active_run


def set_current_flow_run(flow_run: FlowRun | None) -> None:
    """Set ``flow_run`` as the current context's flow run (None clears it)."""
    _current_flow_run.set(flow_run)
@@ -1,9 +1,9 @@
1
1
  from __future__ import annotations
2
2
 
3
- from typing import TYPE_CHECKING, Protocol
3
+ from typing import Protocol
4
4
 
5
- if TYPE_CHECKING:
6
- from goose.flow import FlowRun
5
+ from goose._internal.flow import FlowRun
6
+ from goose._internal.state import FlowRunState
7
7
 
8
8
 
9
9
  class IFlowRunStore(Protocol):
@@ -16,13 +16,15 @@ class IFlowRunStore(Protocol):
16
16
  class InMemoryFlowRunStore(IFlowRunStore):
17
17
  def __init__(self, *, flow_name: str) -> None:
18
18
  self._flow_name = flow_name
19
- self._runs: dict[str, FlowRun] = {}
19
+ self._runs: dict[str, FlowRunState] = {}
20
20
 
21
21
  async def get(self, *, run_id: str) -> FlowRun | None:
22
- return self._runs.get(run_id)
22
+ state = self._runs.get(run_id)
23
+ if state is not None:
24
+ return FlowRun.load(state)
23
25
 
24
26
  async def save(self, *, run: FlowRun) -> None:
25
- self._runs[run.id] = run
27
+ self._runs[run.id] = run.dump()
26
28
 
27
29
  async def delete(self, *, run_id: str) -> None:
28
30
  self._runs.pop(run_id, None)