goose-py 0.5.1.tar.gz → 0.6.0.tar.gz

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: goose-py
- Version: 0.5.1
+ Version: 0.6.0
  Summary: A tool for AI workflows based on human-computer collaboration and structured output.
  Home-page: https://github.com/chelle-ai/goose
  Keywords: ai,yaml,configuration,llm
@@ -0,0 +1,5 @@
+ from goose._agent import Agent, AgentResponse, IAgentLogger
+ from goose._flow import flow
+ from goose._result import Result, TextResult
+ from goose._store import IFlowRunStore
+ from goose._task import task
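Taken together, this new `goose/__init__.py` becomes the package's public surface in 0.6.0. A minimal usage sketch of how these exports compose, inferred only from the signatures visible in this diff (the `Summary` model, `summarize` task, and `pipeline` flow are illustrative; `Result` is assumed to be a Pydantic model, as the `model_dump_json` calls elsewhere in the diff imply):

```python
import asyncio

from goose import Result, flow, task

class Summary(Result):  # assumption: Result subclasses pydantic.BaseModel
    text: str

@task
async def summarize(text: str) -> Summary:
    return Summary(text=text[:100])  # stand-in for a real LLM call

@flow
async def pipeline(text: str) -> None:
    await summarize(text)

async def main() -> None:
    async with pipeline.start_run(run_id="run-1") as run:
        await pipeline.generate("A long document...")
        # the run tracks per-task node state; .result is the latest Summary
        print(run.get(task=summarize).result.text)

asyncio.run(main())
```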
@@ -1,100 +1,13 @@
  import json
  import logging
  from datetime import datetime
- from enum import StrEnum
- from typing import Any, ClassVar, Literal, NotRequired, Protocol, TypedDict
+ from typing import Any, ClassVar, Protocol, TypedDict
 
  from litellm import acompletion
  from pydantic import BaseModel, computed_field
- from goose.result import Result, TextResult
 
-
- class GeminiModel(StrEnum):
-     PRO = "vertex_ai/gemini-1.5-pro"
-     FLASH = "vertex_ai/gemini-1.5-flash"
-     FLASH_8B = "vertex_ai/gemini-1.5-flash-8b"
-
-
- class UserMediaContentType(StrEnum):
-     # images
-     JPEG = "image/jpeg"
-     PNG = "image/png"
-     WEBP = "image/webp"
-
-     # audio
-     MP3 = "audio/mp3"
-     WAV = "audio/wav"
-
-     # files
-     PDF = "application/pdf"
-
-
- class LLMTextMessagePart(TypedDict):
-     type: Literal["text"]
-     text: str
-
-
- class LLMMediaMessagePart(TypedDict):
-     type: Literal["image_url"]
-     image_url: str
-
-
- class CacheControl(TypedDict):
-     type: Literal["ephemeral"]
-
-
- class LLMMessage(TypedDict):
-     role: Literal["user", "assistant", "system"]
-     content: list[LLMTextMessagePart | LLMMediaMessagePart]
-     cache_control: NotRequired[CacheControl]
-
-
- class TextMessagePart(BaseModel):
-     text: str
-
-     def render(self) -> LLMTextMessagePart:
-         return {"type": "text", "text": self.text}
-
-
- class MediaMessagePart(BaseModel):
-     content_type: UserMediaContentType
-     content: str
-
-     def render(self) -> LLMMediaMessagePart:
-         return {
-             "type": "image_url",
-             "image_url": f"data:{self.content_type};base64,{self.content}",
-         }
-
-
- class UserMessage(BaseModel):
-     parts: list[TextMessagePart | MediaMessagePart]
-
-     def render(self) -> LLMMessage:
-         content: LLMMessage = {
-             "role": "user",
-             "content": [part.render() for part in self.parts],
-         }
-         if any(isinstance(part, MediaMessagePart) for part in self.parts):
-             content["cache_control"] = {"type": "ephemeral"}
-         return content
-
-
- class AssistantMessage(BaseModel):
-     text: str
-
-     def render(self) -> LLMMessage:
-         return {"role": "assistant", "content": [{"type": "text", "text": self.text}]}
-
-
- class SystemMessage(BaseModel):
-     parts: list[TextMessagePart | MediaMessagePart]
-
-     def render(self) -> LLMMessage:
-         return {
-             "role": "system",
-             "content": [part.render() for part in self.parts],
-         }
+ from goose._result import Result, TextResult
+ from goose.types.agent import AssistantMessage, GeminiModel, SystemMessage, UserMessage
 
 
  class AgentResponseDump(TypedDict):
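For downstream code, the practical effect of this hunk is an import move: the model enum and message types no longer live in the agent module but in `goose.types.agent` (added later in this diff), with the agent module re-importing them. A sketch of the migration:

```python
# 0.5.x
from goose.agent import GeminiModel, SystemMessage, UserMessage

# 0.6.0: the same types, now defined in goose.types.agent
from goose.types.agent import GeminiModel, SystemMessage, UserMessage
```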
@@ -0,0 +1,36 @@
+ from pydantic import BaseModel
+
+ from goose._result import Result
+ from goose.types.agent import AssistantMessage, LLMMessage, SystemMessage, UserMessage
+
+
+ class Conversation[R: Result](BaseModel):
+     user_messages: list[UserMessage]
+     result_messages: list[R]
+     context: SystemMessage | None = None
+
+     @property
+     def awaiting_response(self) -> bool:
+         return len(self.user_messages) == len(self.result_messages)
+
+     def render(self) -> list[LLMMessage]:
+         messages: list[LLMMessage] = []
+         if self.context is not None:
+             messages.append(self.context.render())
+
+         for message_index in range(len(self.user_messages)):
+             messages.append(
+                 AssistantMessage(
+                     text=self.result_messages[message_index].model_dump_json()
+                 ).render()
+             )
+             messages.append(self.user_messages[message_index].render())
+
+         if len(self.result_messages) > len(self.user_messages):
+             messages.append(
+                 AssistantMessage(
+                     text=self.result_messages[-1].model_dump_json()
+                 ).render()
+             )
+
+         return messages
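`render()` interleaves each stored result (as an assistant message) with the user message that responded to it, then appends a trailing assistant message when there is one more result than user messages. A small sketch, assuming `TextResult` exposes a `text` field as the adapter code later in this diff suggests:

```python
from goose._conversation import Conversation
from goose._result import TextResult
from goose.types.agent import TextMessagePart, UserMessage

conversation = Conversation[TextResult](
    user_messages=[UserMessage(parts=[TextMessagePart(text="Shorter, please.")])],
    result_messages=[TextResult(text="A first draft."), TextResult(text="A draft.")],
)
# one more result than user messages: the last turn is answered
assert not conversation.awaiting_response
# render() yields assistant(first result), user(feedback), assistant(latest result)
assert len(conversation.render()) == 3
```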
@@ -0,0 +1,106 @@
+ from contextlib import asynccontextmanager
+ from types import CodeType
+ from typing import AsyncIterator, Awaitable, Callable, Protocol, overload
+
+ from goose._agent import Agent, IAgentLogger
+ from goose._conversation import Conversation
+ from goose._result import Result
+ from goose._state import FlowRun, get_current_flow_run, set_current_flow_run
+ from goose._store import IFlowRunStore, InMemoryFlowRunStore
+ from goose.errors import Honk
+
+
+ class IAdapter[ResultT: Result](Protocol):
+     __code__: CodeType
+
+     async def __call__(
+         self, *, conversation: Conversation[ResultT], agent: Agent
+     ) -> ResultT: ...
+
+
+ class Flow[**P]:
+     def __init__(
+         self,
+         fn: Callable[P, Awaitable[None]],
+         /,
+         *,
+         name: str | None = None,
+         store: IFlowRunStore | None = None,
+         agent_logger: IAgentLogger | None = None,
+     ) -> None:
+         self._fn = fn
+         self._name = name
+         self._agent_logger = agent_logger
+         self._store = store or InMemoryFlowRunStore(flow_name=self.name)
+
+     @property
+     def name(self) -> str:
+         return self._name or self._fn.__name__
+
+     @property
+     def current_run(self) -> FlowRun:
+         run = get_current_flow_run()
+         if run is None:
+             raise Honk("No current flow run")
+         return run
+
+     @asynccontextmanager
+     async def start_run(self, *, run_id: str) -> AsyncIterator[FlowRun]:
+         existing_run = await self._store.get(run_id=run_id)
+         if existing_run is None:
+             run = FlowRun()
+         else:
+             run = existing_run
+
+         old_run = get_current_flow_run()
+         set_current_flow_run(run)
+
+         run.start(flow_name=self.name, run_id=run_id, agent_logger=self._agent_logger)
+         yield run
+         await self._store.save(run=run)
+         run.end()
+
+         set_current_flow_run(old_run)
+
+     async def generate(self, *args: P.args, **kwargs: P.kwargs) -> None:
+         flow_run = get_current_flow_run()
+         if flow_run is None:
+             raise Honk("No current flow run")
+
+         flow_run.set_flow_inputs(*args, **kwargs)
+         await self._fn(*args, **kwargs)
+
+     async def regenerate(self) -> None:
+         flow_run = get_current_flow_run()
+         if flow_run is None:
+             raise Honk("No current flow run")
+
+         flow_args, flow_kwargs = flow_run.flow_inputs
+         await self._fn(*flow_args, **flow_kwargs)
+
+
+ @overload
+ def flow[**P](fn: Callable[P, Awaitable[None]], /) -> Flow[P]: ...
+ @overload
+ def flow[**P](
+     *,
+     name: str | None = None,
+     store: IFlowRunStore | None = None,
+     agent_logger: IAgentLogger | None = None,
+ ) -> Callable[[Callable[P, Awaitable[None]]], Flow[P]]: ...
+ def flow[**P](
+     fn: Callable[P, Awaitable[None]] | None = None,
+     /,
+     *,
+     name: str | None = None,
+     store: IFlowRunStore | None = None,
+     agent_logger: IAgentLogger | None = None,
+ ) -> Flow[P] | Callable[[Callable[P, Awaitable[None]]], Flow[P]]:
+     if fn is None:
+
+         def decorator(fn: Callable[P, Awaitable[None]]) -> Flow[P]:
+             return Flow(fn, name=name, store=store, agent_logger=agent_logger)
+
+         return decorator
+
+     return Flow(fn, name=name, store=store, agent_logger=agent_logger)
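A sketch of the run lifecycle these pieces add up to, using the default in-memory store (the `report` flow and run id are illustrative): `generate()` records the call's inputs on the run, the run is saved to the store when `start_run` exits, and reopening the same `run_id` later lets `regenerate()` replay the flow from the recorded inputs.

```python
from goose import flow

@flow  # bare form; flow(name=..., store=..., agent_logger=...) also works
async def report(topic: str) -> None:
    ...  # await tasks here

async def main() -> None:
    async with report.start_run(run_id="run-1") as run:
        await report.generate("geese")  # records inputs, then runs the flow body
    # on exit: the run is saved to the store, then ended

    async with report.start_run(run_id="run-1"):
        await report.regenerate()  # replays the flow with the recorded inputs
```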
@@ -0,0 +1,190 @@
+ from contextvars import ContextVar
+ from dataclasses import dataclass
+ from typing import TYPE_CHECKING, Any, Self
+
+ from pydantic import BaseModel
+
+ from goose._agent import (
+     Agent,
+     IAgentLogger,
+     SystemMessage,
+     UserMessage,
+ )
+ from goose._conversation import Conversation
+ from goose._result import Result
+ from goose.errors import Honk
+
+ if TYPE_CHECKING:
+     from goose._task import Task
+
+
+ @dataclass
+ class FlowRunState:
+     node_states: dict[tuple[str, int], str]
+     flow_args: tuple[Any, ...]
+     flow_kwargs: dict[str, Any]
+
+
+ class NodeState[ResultT: Result](BaseModel):
+     task_name: str
+     index: int
+     conversation: Conversation[ResultT]
+     last_hash: int
+
+     @property
+     def result(self) -> ResultT:
+         if len(self.conversation.result_messages) == 0:
+             raise Honk("Node awaiting response, has no result")
+
+         return self.conversation.result_messages[-1]
+
+     def set_context(self, *, context: SystemMessage) -> Self:
+         self.conversation.context = context
+         return self
+
+     def add_result(
+         self,
+         *,
+         result: ResultT,
+         new_hash: int | None = None,
+         overwrite: bool = False,
+     ) -> Self:
+         if overwrite and len(self.conversation.result_messages) > 0:
+             self.conversation.result_messages[-1] = result
+         else:
+             self.conversation.result_messages.append(result)
+         if new_hash is not None:
+             self.last_hash = new_hash
+         return self
+
+     def add_user_message(self, *, message: UserMessage) -> Self:
+         self.conversation.user_messages.append(message)
+         return self
+
+
+ class FlowRun:
+     def __init__(self) -> None:
+         self._node_states: dict[tuple[str, int], str] = {}
+         self._last_requested_indices: dict[str, int] = {}
+         self._flow_name = ""
+         self._id = ""
+         self._agent: Agent | None = None
+         self._flow_args: tuple[Any, ...] | None = None
+         self._flow_kwargs: dict[str, Any] | None = None
+
+     @property
+     def flow_name(self) -> str:
+         return self._flow_name
+
+     @property
+     def id(self) -> str:
+         return self._id
+
+     @property
+     def agent(self) -> Agent:
+         if self._agent is None:
+             raise Honk("Agent is only accessible once a run is started")
+         return self._agent
+
+     @property
+     def flow_inputs(self) -> tuple[tuple[Any, ...], dict[str, Any]]:
+         if self._flow_args is None or self._flow_kwargs is None:
+             raise Honk("This Flow run has not been executed before")
+
+         return self._flow_args, self._flow_kwargs
+
+     def get_all[R: Result](self, *, task: "Task[Any, R]") -> list[NodeState[R]]:
+         matching_nodes: list[NodeState[R]] = []
+         for key, node_state in self._node_states.items():
+             if key[0] == task.name:
+                 matching_nodes.append(
+                     NodeState[task.result_type].model_validate_json(node_state)
+                 )
+         return sorted(matching_nodes, key=lambda node: node.index)
+
+     def get[R: Result](self, *, task: "Task[Any, R]", index: int = 0) -> NodeState[R]:
+         if (
+             existing_node_state := self._node_states.get((task.name, index))
+         ) is not None:
+             return NodeState[task.result_type].model_validate_json(existing_node_state)
+         else:
+             return NodeState[task.result_type](
+                 task_name=task.name,
+                 index=index,
+                 conversation=Conversation[task.result_type](
+                     user_messages=[], result_messages=[]
+                 ),
+                 last_hash=0,
+             )
+
+     def set_flow_inputs(self, *args: Any, **kwargs: Any) -> None:
+         self._flow_args = args
+         self._flow_kwargs = kwargs
+
+     def add_node_state(self, node_state: NodeState[Any], /) -> None:
+         key = (node_state.task_name, node_state.index)
+         self._node_states[key] = node_state.model_dump_json()
+
+     def get_next[R: Result](self, *, task: "Task[Any, R]") -> NodeState[R]:
+         if task.name not in self._last_requested_indices:
+             self._last_requested_indices[task.name] = 0
+         else:
+             self._last_requested_indices[task.name] += 1
+
+         return self.get(task=task, index=self._last_requested_indices[task.name])
+
+     def start(
+         self,
+         *,
+         flow_name: str,
+         run_id: str,
+         agent_logger: IAgentLogger | None = None,
+     ) -> None:
+         self._last_requested_indices = {}
+         self._flow_name = flow_name
+         self._id = run_id
+         self._agent = Agent(
+             flow_name=self.flow_name, run_id=self.id, logger=agent_logger
+         )
+
+     def end(self) -> None:
+         self._last_requested_indices = {}
+         self._flow_name = ""
+         self._id = ""
+         self._agent = None
+
+     def clear_node(self, *, task: "Task[Any, Result]", index: int) -> None:
+         key = (task.name, index)
+         if key in self._node_states:
+             del self._node_states[key]
+
+     def dump(self) -> FlowRunState:
+         flow_args, flow_kwargs = self.flow_inputs
+
+         return FlowRunState(
+             node_states=self._node_states,
+             flow_args=flow_args,
+             flow_kwargs=flow_kwargs,
+         )
+
+     @classmethod
+     def load(cls, flow_run_state: FlowRunState, /) -> Self:
+         flow_run = cls()
+         flow_run._node_states = flow_run_state.node_states
+         flow_run._flow_args = flow_run_state.flow_args
+         flow_run._flow_kwargs = flow_run_state.flow_kwargs
+
+         return flow_run
+
+
+ _current_flow_run: ContextVar[FlowRun | None] = ContextVar(
+     "current_flow_run", default=None
+ )
+
+
+ def get_current_flow_run() -> FlowRun | None:
+     return _current_flow_run.get()
+
+
+ def set_current_flow_run(flow_run: FlowRun | None) -> None:
+     _current_flow_run.set(flow_run)
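A sketch of the new state round-trip, based only on what this file defines: `dump()` snapshots the serialized node states together with the recorded flow inputs into a `FlowRunState`, and `load()` rebuilds an equivalent `FlowRun` from it. The module-level `ContextVar` accessors are how tasks locate the active run.

```python
from goose._state import FlowRun, get_current_flow_run, set_current_flow_run

run = FlowRun()
run.set_flow_inputs("geese", limit=3)

state = run.dump()            # FlowRunState(node_states=..., flow_args=..., flow_kwargs=...)
restored = FlowRun.load(state)
assert restored.flow_inputs == (("geese",), {"limit": 3})

set_current_flow_run(restored)
assert get_current_flow_run() is restored
```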
@@ -1,9 +1,9 @@
  from __future__ import annotations
 
- from typing import TYPE_CHECKING, Protocol
+ from typing import Protocol
 
- if TYPE_CHECKING:
-     from goose.flow import FlowRun
+ from goose._flow import FlowRun
+ from goose._state import FlowRunState
 
 
  class IFlowRunStore(Protocol):
@@ -16,13 +16,15 @@ class IFlowRunStore(Protocol):
  class InMemoryFlowRunStore(IFlowRunStore):
      def __init__(self, *, flow_name: str) -> None:
          self._flow_name = flow_name
-         self._runs: dict[str, FlowRun] = {}
+         self._runs: dict[str, FlowRunState] = {}
 
      async def get(self, *, run_id: str) -> FlowRun | None:
-         return self._runs.get(run_id)
+         state = self._runs.get(run_id)
+         if state is not None:
+             return FlowRun.load(state)
 
      async def save(self, *, run: FlowRun) -> None:
-         self._runs[run.id] = run
+         self._runs[run.id] = run.dump()
 
      async def delete(self, *, run_id: str) -> None:
          self._runs.pop(run_id, None)
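Since `IFlowRunStore` is a `Protocol`, persistence can be swapped without touching flows. Below is a hypothetical file-backed store, assuming the protocol matches the `get`/`save`/`delete` surface shown above; `FileFlowRunStore` is not part of the package. Pickle is used because `FlowRunState` keys node states by `(task_name, index)` tuples, which plain JSON cannot represent directly.

```python
import pickle
from pathlib import Path

from goose._state import FlowRun

class FileFlowRunStore:
    """Illustrative IFlowRunStore: one pickled FlowRunState per run."""

    def __init__(self, *, flow_name: str, root: Path) -> None:
        self._dir = root / flow_name
        self._dir.mkdir(parents=True, exist_ok=True)

    async def get(self, *, run_id: str) -> FlowRun | None:
        path = self._dir / f"{run_id}.pkl"
        if not path.exists():
            return None
        return FlowRun.load(pickle.loads(path.read_bytes()))

    async def save(self, *, run: FlowRun) -> None:
        (self._dir / f"{run.id}.pkl").write_bytes(pickle.dumps(run.dump()))

    async def delete(self, *, run_id: str) -> None:
        (self._dir / f"{run_id}.pkl").unlink(missing_ok=True)
```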
@@ -0,0 +1,136 @@
+ from typing import Awaitable, Callable, overload
+
+ from goose._agent import Agent, GeminiModel, SystemMessage, UserMessage
+ from goose._conversation import Conversation
+ from goose._result import Result, TextResult
+ from goose._state import FlowRun, NodeState, get_current_flow_run
+ from goose.errors import Honk
+ from goose.types.agent import AssistantMessage
+
+
+ class Task[**P, R: Result]:
+     def __init__(
+         self,
+         generator: Callable[P, Awaitable[R]],
+         /,
+         *,
+         retries: int = 0,
+         adapter_model: GeminiModel = GeminiModel.FLASH,
+     ) -> None:
+         self._generator = generator
+         self._retries = retries
+         self._adapter_model = adapter_model
+         self._adapter_model = adapter_model
+
+     @property
+     def result_type(self) -> type[R]:
+         result_type = self._generator.__annotations__.get("return")
+         if result_type is None:
+             raise Honk(f"Task {self.name} has no return type annotation")
+         return result_type
+
+     @property
+     def name(self) -> str:
+         return self._generator.__name__
+
+     async def generate(
+         self, state: NodeState[R], *args: P.args, **kwargs: P.kwargs
+     ) -> R:
+         state_hash = self.__hash_task_call(*args, **kwargs)
+         if state_hash != state.last_hash:
+             result = await self._generator(*args, **kwargs)
+             state.add_result(result=result, new_hash=state_hash, overwrite=True)
+             return result
+         else:
+             return state.result
+
+     async def jam(
+         self,
+         *,
+         user_message: UserMessage,
+         context: SystemMessage | None = None,
+         index: int = 0,
+     ) -> R:
+         flow_run = self.__get_current_flow_run()
+         node_state = flow_run.get(task=self, index=index)
+
+         if context is not None:
+             node_state.set_context(context=context)
+         node_state.add_user_message(message=user_message)
+
+         result = await self.__adapt(
+             conversation=node_state.conversation, agent=flow_run.agent
+         )
+         node_state.add_result(result=result)
+         flow_run.add_node_state(node_state)
+
+         return result
+
+     async def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R:
+         flow_run = self.__get_current_flow_run()
+         node_state = flow_run.get_next(task=self)
+         result = await self.generate(node_state, *args, **kwargs)
+         flow_run.add_node_state(node_state)
+         return result
+
+     async def __adapt(self, *, conversation: Conversation[R], agent: Agent) -> R:
+         messages: list[UserMessage | AssistantMessage] = []
+         for message_index in range(len(conversation.user_messages)):
+             user_message = conversation.user_messages[message_index]
+             result = conversation.result_messages[message_index]
+
+             if isinstance(result, TextResult):
+                 assistant_text = result.text
+             else:
+                 assistant_text = result.model_dump_json()
+             assistant_message = AssistantMessage(text=assistant_text)
+             messages.append(assistant_message)
+             messages.append(user_message)
+
+         return await agent(
+             messages=messages,
+             model=self._adapter_model,
+             task_name=f"adapt--{self.name}",
+             system=conversation.context,
+             response_model=self.result_type,
+         )
+
+     def __hash_task_call(self, *args: P.args, **kwargs: P.kwargs) -> int:
+         try:
+             to_hash = str(
+                 tuple(args)
+                 + tuple(kwargs.values())
+                 + (self._generator.__code__, self._adapter_model)
+             )
+             return hash(to_hash)
+         except TypeError:
+             raise Honk(f"Unhashable argument to task {self.name}: {args} {kwargs}")
+
+     def __get_current_flow_run(self) -> FlowRun:
+         run = get_current_flow_run()
+         if run is None:
+             raise Honk("No current flow run")
+         return run
+
+
+ @overload
+ def task[**P, R: Result](generator: Callable[P, Awaitable[R]], /) -> Task[P, R]: ...
+ @overload
+ def task[**P, R: Result](
+     *, retries: int = 0, adapter_model: GeminiModel = GeminiModel.FLASH
+ ) -> Callable[[Callable[P, Awaitable[R]]], Task[P, R]]: ...
+ def task[**P, R: Result](
+     generator: Callable[P, Awaitable[R]] | None = None,
+     /,
+     *,
+     retries: int = 0,
+     adapter_model: GeminiModel = GeminiModel.FLASH,
+ ) -> Task[P, R] | Callable[[Callable[P, Awaitable[R]]], Task[P, R]]:
+     if generator is None:
+
+         def decorator(fn: Callable[P, Awaitable[R]]) -> Task[P, R]:
+             return Task(fn, retries=retries, adapter_model=adapter_model)
+
+         return decorator
+
+     return Task(generator, retries=retries, adapter_model=adapter_model)
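A sketch of the two execution paths a `Task` supports (the `draft` and `report` names are illustrative, and `TextResult` is assumed to expose a `text` field as `__adapt` above implies): calls whose arguments hash to the node's stored `last_hash` reuse the cached result on `regenerate()`, while `jam()` appends user feedback and asks the agent to revise the node's result, which means a live agent call in practice.

```python
from goose import flow, task
from goose._result import TextResult
from goose.types.agent import TextMessagePart, UserMessage

@task
async def draft(topic: str) -> TextResult:
    return TextResult(text=f"A draft about {topic}")  # stand-in generator

@flow
async def report(topic: str) -> None:
    await draft(topic)

async def main() -> None:
    async with report.start_run(run_id="run-1"):
        await report.generate("geese")  # draft() runs; its call hash is stored

    async with report.start_run(run_id="run-1"):
        await report.regenerate()  # same args, same hash: cached result reused

        # jam() sends feedback to the agent to revise the node's result
        await draft.jam(
            user_message=UserMessage(
                parts=[TextMessagePart(text="Make it shorter.")]
            )
        )
```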
@@ -0,0 +1,92 @@
+ from enum import StrEnum
+ from typing import Literal, NotRequired, TypedDict
+
+ from pydantic import BaseModel
+
+
+ class GeminiModel(StrEnum):
+     PRO = "vertex_ai/gemini-1.5-pro"
+     FLASH = "vertex_ai/gemini-1.5-flash"
+     FLASH_8B = "vertex_ai/gemini-1.5-flash-8b"
+
+
+ class UserMediaContentType(StrEnum):
+     # images
+     JPEG = "image/jpeg"
+     PNG = "image/png"
+     WEBP = "image/webp"
+
+     # audio
+     MP3 = "audio/mp3"
+     WAV = "audio/wav"
+
+     # files
+     PDF = "application/pdf"
+
+
+ class LLMTextMessagePart(TypedDict):
+     type: Literal["text"]
+     text: str
+
+
+ class LLMMediaMessagePart(TypedDict):
+     type: Literal["image_url"]
+     image_url: str
+
+
+ class CacheControl(TypedDict):
+     type: Literal["ephemeral"]
+
+
+ class LLMMessage(TypedDict):
+     role: Literal["user", "assistant", "system"]
+     content: list[LLMTextMessagePart | LLMMediaMessagePart]
+     cache_control: NotRequired[CacheControl]
+
+
+ class TextMessagePart(BaseModel):
+     text: str
+
+     def render(self) -> LLMTextMessagePart:
+         return {"type": "text", "text": self.text}
+
+
+ class MediaMessagePart(BaseModel):
+     content_type: UserMediaContentType
+     content: str
+
+     def render(self) -> LLMMediaMessagePart:
+         return {
+             "type": "image_url",
+             "image_url": f"data:{self.content_type};base64,{self.content}",
+         }
+
+
+ class UserMessage(BaseModel):
+     parts: list[TextMessagePart | MediaMessagePart]
+
+     def render(self) -> LLMMessage:
+         content: LLMMessage = {
+             "role": "user",
+             "content": [part.render() for part in self.parts],
+         }
+         if any(isinstance(part, MediaMessagePart) for part in self.parts):
+             content["cache_control"] = {"type": "ephemeral"}
+         return content
+
+
+ class AssistantMessage(BaseModel):
+     text: str
+
+     def render(self) -> LLMMessage:
+         return {"role": "assistant", "content": [{"type": "text", "text": self.text}]}
+
+
+ class SystemMessage(BaseModel):
+     parts: list[TextMessagePart | MediaMessagePart]
+
+     def render(self) -> LLMMessage:
+         return {
+             "role": "system",
+             "content": [part.render() for part in self.parts],
+         }
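A sketch of how these message models render into litellm-style payloads; note that `UserMessage.render()` adds ephemeral `cache_control` only when a media part is present (the image bytes here are a placeholder):

```python
import base64

from goose.types.agent import (
    MediaMessagePart,
    TextMessagePart,
    UserMediaContentType,
    UserMessage,
)

png_bytes = b"..."  # raw image bytes, placeholder
message = UserMessage(
    parts=[
        TextMessagePart(text="What is in this image?"),
        MediaMessagePart(
            content_type=UserMediaContentType.PNG,
            content=base64.b64encode(png_bytes).decode(),
        ),
    ]
)
llm_message = message.render()
assert llm_message["role"] == "user"
assert llm_message["cache_control"] == {"type": "ephemeral"}
```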
@@ -1,6 +1,6 @@
  [tool.poetry]
  name = "goose-py"
- version = "0.5.1"
+ version = "0.6.0"
  description = "A tool for AI workflows based on human-computer collaboration and structured output."
  authors = [
      "Nash Taylor <nash@chelle.ai>",
@@ -53,6 +53,10 @@ reportImportCycles = false
  reportUnknownMemberType = false
  reportUnknownVariableType = false
  stubPath = ".stubs"
+ exclude = [
+     "goose/__init__.py",
+     ".venv",
+ ]
 
  [tool.pytest.ini_options]
  filterwarnings = [
@@ -1,458 +0,0 @@
- import json
- from contextlib import asynccontextmanager
- from contextvars import ContextVar
- from types import CodeType
- from typing import (
-     Any,
-     AsyncIterator,
-     Awaitable,
-     Callable,
-     NewType,
-     Protocol,
-     Self,
-     overload,
- )
-
- from pydantic import BaseModel
-
- from goose.agent import (
-     Agent,
-     AssistantMessage,
-     GeminiModel,
-     IAgentLogger,
-     LLMMessage,
-     SystemMessage,
-     UserMessage,
- )
- from goose.errors import Honk
- from goose.result import Result, TextResult
- from goose.store import IFlowRunStore, InMemoryFlowRunStore
-
- SerializedFlowRun = NewType("SerializedFlowRun", str)
-
-
- class Conversation[R: Result](BaseModel):
-     user_messages: list[UserMessage]
-     result_messages: list[R]
-     context: SystemMessage | None = None
-
-     @property
-     def awaiting_response(self) -> bool:
-         return len(self.user_messages) == len(self.result_messages)
-
-     def render(self) -> list[LLMMessage]:
-         messages: list[LLMMessage] = []
-         if self.context is not None:
-             messages.append(self.context.render())
-
-         for message_index in range(len(self.user_messages)):
-             messages.append(
-                 AssistantMessage(
-                     text=self.result_messages[message_index].model_dump_json()
-                 ).render()
-             )
-             messages.append(self.user_messages[message_index].render())
-
-         if len(self.result_messages) > len(self.user_messages):
-             messages.append(
-                 AssistantMessage(
-                     text=self.result_messages[-1].model_dump_json()
-                 ).render()
-             )
-
-         return messages
-
-
- class IAdapter[ResultT: Result](Protocol):
-     __code__: CodeType
-
-     async def __call__(
-         self, *, conversation: Conversation[ResultT], agent: Agent
-     ) -> ResultT: ...
-
-
- class NodeState[ResultT: Result](BaseModel):
-     task_name: str
-     index: int
-     conversation: Conversation[ResultT]
-     last_hash: int
-
-     @property
-     def result(self) -> ResultT:
-         if len(self.conversation.result_messages) == 0:
-             raise Honk("Node awaiting response, has no result")
-
-         return self.conversation.result_messages[-1]
-
-     def set_context(self, *, context: SystemMessage) -> Self:
-         self.conversation.context = context
-         return self
-
-     def add_result(
-         self,
-         *,
-         result: ResultT,
-         new_hash: int | None = None,
-         overwrite: bool = False,
-     ) -> Self:
-         if overwrite and len(self.conversation.result_messages) > 0:
-             self.conversation.result_messages[-1] = result
-         else:
-             self.conversation.result_messages.append(result)
-         if new_hash is not None:
-             self.last_hash = new_hash
-         return self
-
-     def add_user_message(self, *, message: UserMessage) -> Self:
-         self.conversation.user_messages.append(message)
-         return self
-
-
- class FlowRun:
-     def __init__(self) -> None:
-         self._node_states: dict[tuple[str, int], str] = {}
-         self._last_requested_indices: dict[str, int] = {}
-         self._flow_name = ""
-         self._id = ""
-         self._agent: Agent | None = None
-         self._flow_args: tuple[Any, ...] | None = None
-         self._flow_kwargs: dict[str, Any] | None = None
-
-     @property
-     def flow_name(self) -> str:
-         return self._flow_name
-
-     @property
-     def id(self) -> str:
-         return self._id
-
-     @property
-     def agent(self) -> Agent:
-         if self._agent is None:
-             raise Honk("Agent is only accessible once a run is started")
-         return self._agent
-
-     @property
-     def flow_inputs(self) -> tuple[tuple[Any, ...], dict[str, Any]]:
-         if self._flow_args is None or self._flow_kwargs is None:
-             raise Honk("This Flow run has not been executed before")
-
-         return self._flow_args, self._flow_kwargs
-
-     def get_all[R: Result](self, *, task: "Task[Any, R]") -> list[NodeState[R]]:
-         matching_nodes: list[NodeState[R]] = []
-         for key, node_state in self._node_states.items():
-             if key[0] == task.name:
-                 matching_nodes.append(
-                     NodeState[task.result_type].model_validate_json(node_state)
-                 )
-         return sorted(matching_nodes, key=lambda node: node.index)
-
-     def get[R: Result](self, *, task: "Task[Any, R]", index: int = 0) -> NodeState[R]:
-         if (
-             existing_node_state := self._node_states.get((task.name, index))
-         ) is not None:
-             return NodeState[task.result_type].model_validate_json(existing_node_state)
-         else:
-             return NodeState[task.result_type](
-                 task_name=task.name,
-                 index=index,
-                 conversation=Conversation[task.result_type](
-                     user_messages=[], result_messages=[]
-                 ),
-                 last_hash=0,
-             )
-
-     def set_flow_inputs(self, *args: Any, **kwargs: Any) -> None:
-         self._flow_args = args
-         self._flow_kwargs = kwargs
-
-     def add_node_state(self, node_state: NodeState[Any], /) -> None:
-         key = (node_state.task_name, node_state.index)
-         self._node_states[key] = node_state.model_dump_json()
-
-     def get_next[R: Result](self, *, task: "Task[Any, R]") -> NodeState[R]:
-         if task.name not in self._last_requested_indices:
-             self._last_requested_indices[task.name] = 0
-         else:
-             self._last_requested_indices[task.name] += 1
-
-         return self.get(task=task, index=self._last_requested_indices[task.name])
-
-     def start(
-         self,
-         *,
-         flow_name: str,
-         run_id: str,
-         agent_logger: IAgentLogger | None = None,
-     ) -> None:
-         self._last_requested_indices = {}
-         self._flow_name = flow_name
-         self._id = run_id
-         self._agent = Agent(
-             flow_name=self.flow_name, run_id=self.id, logger=agent_logger
-         )
-
-     def end(self) -> None:
-         self._last_requested_indices = {}
-         self._flow_name = ""
-         self._id = ""
-         self._agent = None
-
-     def clear_node(self, *, task: "Task[Any, Result]", index: int) -> None:
-         key = (task.name, index)
-         if key in self._node_states:
-             del self._node_states[key]
-
-     def dump(self) -> SerializedFlowRun:
-         flow_args, flow_kwargs = self.flow_inputs
-
-         return SerializedFlowRun(
-             json.dumps(
-                 {
-                     "node_states": {
-                         ":".join([task_name, str(index)]): value
-                         for (task_name, index), value in self._node_states.items()
-                     },
-                     "flow_args": list(flow_args),
-                     "flow_kwargs": flow_kwargs,
-                 }
-             )
-         )
-
-     @classmethod
-     def load(cls, serialized_flow_run: SerializedFlowRun, /) -> Self:
-         flow_run = cls()
-         run = json.loads(serialized_flow_run)
-
-         new_node_states: dict[tuple[str, int], str] = {}
-         for key, node_state in run["node_states"].items():
-             task_name, index = tuple(key.split(":"))
-             new_node_states[(task_name, int(index))] = node_state
-         flow_run._node_states = new_node_states
-
-         flow_run._flow_args = tuple(run["flow_args"])
-         flow_run._flow_kwargs = run["flow_kwargs"]
-
-         return flow_run
-
-
- _current_flow_run: ContextVar[FlowRun | None] = ContextVar(
-     "current_flow_run", default=None
- )
-
-
- class Flow[**P]:
-     def __init__(
-         self,
-         fn: Callable[P, Awaitable[None]],
-         /,
-         *,
-         name: str | None = None,
-         store: IFlowRunStore | None = None,
-         agent_logger: IAgentLogger | None = None,
-     ) -> None:
-         self._fn = fn
-         self._name = name
-         self._agent_logger = agent_logger
-         self._store = store or InMemoryFlowRunStore(flow_name=self.name)
-
-     @property
-     def name(self) -> str:
-         return self._name or self._fn.__name__
-
-     @property
-     def current_run(self) -> FlowRun:
-         run = _current_flow_run.get()
-         if run is None:
-             raise Honk("No current flow run")
-         return run
-
-     @asynccontextmanager
-     async def start_run(self, *, run_id: str) -> AsyncIterator[FlowRun]:
-         existing_run = await self._store.get(run_id=run_id)
-         if existing_run is None:
-             run = FlowRun()
-         else:
-             run = existing_run
-
-         old_run = _current_flow_run.get()
-         _current_flow_run.set(run)
-
-         run.start(flow_name=self.name, run_id=run_id, agent_logger=self._agent_logger)
-         yield run
-         await self._store.save(run=run)
-         run.end()
-
-         _current_flow_run.set(old_run)
-
-     async def generate(self, *args: P.args, **kwargs: P.kwargs) -> None:
-         flow_run = _current_flow_run.get()
-         if flow_run is None:
-             raise Honk("No current flow run")
-
-         flow_run.set_flow_inputs(*args, **kwargs)
-         await self._fn(*args, **kwargs)
-
-     async def regenerate(self) -> None:
-         flow_run = _current_flow_run.get()
-         if flow_run is None:
-             raise Honk("No current flow run")
-
-         flow_args, flow_kwargs = flow_run.flow_inputs
-         await self._fn(*flow_args, **flow_kwargs)
-
-
- class Task[**P, R: Result]:
-     def __init__(
-         self,
-         generator: Callable[P, Awaitable[R]],
-         /,
-         *,
-         retries: int = 0,
-         adapter_model: GeminiModel = GeminiModel.FLASH,
-     ) -> None:
-         self._generator = generator
-         self._retries = retries
-         self._adapter_model = adapter_model
-         self._adapter_model = adapter_model
-
-     @property
-     def result_type(self) -> type[R]:
-         result_type = self._generator.__annotations__.get("return")
-         if result_type is None:
-             raise Honk(f"Task {self.name} has no return type annotation")
-         return result_type
-
-     @property
-     def name(self) -> str:
-         return self._generator.__name__
-
-     async def generate(
-         self, state: NodeState[R], *args: P.args, **kwargs: P.kwargs
-     ) -> R:
-         state_hash = self.__hash_task_call(*args, **kwargs)
-         if state_hash != state.last_hash:
-             result = await self._generator(*args, **kwargs)
-             state.add_result(result=result, new_hash=state_hash, overwrite=True)
-             return result
-         else:
-             return state.result
-
-     async def jam(
-         self,
-         *,
-         user_message: UserMessage,
-         context: SystemMessage | None = None,
-         index: int = 0,
-     ) -> R:
-         flow_run = self.__get_current_flow_run()
-         node_state = flow_run.get(task=self, index=index)
-
-         if context is not None:
-             node_state.set_context(context=context)
-         node_state.add_user_message(message=user_message)
-
-         result = await self.__adapt(
-             conversation=node_state.conversation, agent=flow_run.agent
-         )
-         node_state.add_result(result=result)
-         flow_run.add_node_state(node_state)
-
-         return result
-
-     async def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R:
-         flow_run = self.__get_current_flow_run()
-         node_state = flow_run.get_next(task=self)
-         result = await self.generate(node_state, *args, **kwargs)
-         flow_run.add_node_state(node_state)
-         return result
-
-     async def __adapt(self, *, conversation: Conversation[R], agent: Agent) -> R:
-         messages: list[UserMessage | AssistantMessage] = []
-         for message_index in range(len(conversation.user_messages)):
-             user_message = conversation.user_messages[message_index]
-             result = conversation.result_messages[message_index]
-
-             if isinstance(result, TextResult):
-                 assistant_text = result.text
-             else:
-                 assistant_text = result.model_dump_json()
-             assistant_message = AssistantMessage(text=assistant_text)
-             messages.append(assistant_message)
-             messages.append(user_message)
-
-         return await agent(
-             messages=messages,
-             model=self._adapter_model,
-             task_name=f"adapt--{self.name}",
-             system=conversation.context,
-             response_model=self.result_type,
-         )
-
-     def __hash_task_call(self, *args: P.args, **kwargs: P.kwargs) -> int:
-         try:
-             to_hash = str(
-                 tuple(args)
-                 + tuple(kwargs.values())
-                 + (self._generator.__code__, self._adapter_model)
-             )
-             return hash(to_hash)
-         except TypeError:
-             raise Honk(f"Unhashable argument to task {self.name}: {args} {kwargs}")
-
-     def __get_current_flow_run(self) -> FlowRun:
-         run = _current_flow_run.get()
-         if run is None:
-             raise Honk("No current flow run")
-         return run
-
-
- @overload
- def task[**P, R: Result](generator: Callable[P, Awaitable[R]], /) -> Task[P, R]: ...
- @overload
- def task[**P, R: Result](
-     *, retries: int = 0, adapter_model: GeminiModel = GeminiModel.FLASH
- ) -> Callable[[Callable[P, Awaitable[R]]], Task[P, R]]: ...
- def task[**P, R: Result](
-     generator: Callable[P, Awaitable[R]] | None = None,
-     /,
-     *,
-     retries: int = 0,
-     adapter_model: GeminiModel = GeminiModel.FLASH,
- ) -> Task[P, R] | Callable[[Callable[P, Awaitable[R]]], Task[P, R]]:
-     if generator is None:
-
-         def decorator(fn: Callable[P, Awaitable[R]]) -> Task[P, R]:
-             return Task(fn, retries=retries, adapter_model=adapter_model)
-
-         return decorator
-
-     return Task(generator, retries=retries, adapter_model=adapter_model)
-
-
- @overload
- def flow[**P](fn: Callable[P, Awaitable[None]], /) -> Flow[P]: ...
- @overload
- def flow[**P](
-     *,
-     name: str | None = None,
-     store: IFlowRunStore | None = None,
-     agent_logger: IAgentLogger | None = None,
- ) -> Callable[[Callable[P, Awaitable[None]]], Flow[P]]: ...
- def flow[**P](
-     fn: Callable[P, Awaitable[None]] | None = None,
-     /,
-     *,
-     name: str | None = None,
-     store: IFlowRunStore | None = None,
-     agent_logger: IAgentLogger | None = None,
- ) -> Flow[P] | Callable[[Callable[P, Awaitable[None]]], Flow[P]]:
-     if fn is None:
-
-         def decorator(fn: Callable[P, Awaitable[None]]) -> Flow[P]:
-             return Flow(fn, name=name, store=store, agent_logger=agent_logger)
-
-         return decorator
-
-     return Flow(fn, name=name, store=store, agent_logger=agent_logger)