goose-py 0.5.1__py3-none-any.whl → 0.7.0__py3-none-any.whl
This diff shows the contents of publicly released package versions as they appear in their public registries, and is provided for informational purposes only.
- goose/__init__.py +4 -0
- goose/_internal/agent.py +201 -0
- goose/_internal/conversation.py +41 -0
- goose/_internal/flow.py +106 -0
- goose/_internal/state.py +190 -0
- goose/{store.py → _internal/store.py} +8 -6
- goose/_internal/task.py +136 -0
- goose/_internal/types/__init__.py +0 -0
- goose/_internal/types/agent.py +92 -0
- goose/agent.py +28 -283
- goose/flow.py +2 -457
- goose/runs.py +4 -0
- goose_py-0.7.0.dist-info/METADATA +14 -0
- goose_py-0.7.0.dist-info/RECORD +18 -0
- {goose_py-0.5.1.dist-info → goose_py-0.7.0.dist-info}/WHEEL +1 -1
- goose_py-0.5.1.dist-info/METADATA +0 -31
- goose_py-0.5.1.dist-info/RECORD +0 -10
- /goose/{result.py → _internal/result.py} +0 -0
goose/_internal/task.py
ADDED
@@ -0,0 +1,136 @@
+from typing import Awaitable, Callable, overload
+
+from goose._internal.agent import Agent, GeminiModel, SystemMessage, UserMessage
+from goose._internal.conversation import Conversation
+from goose._internal.result import Result, TextResult
+from goose._internal.state import FlowRun, NodeState, get_current_flow_run
+from goose._internal.types.agent import AssistantMessage
+from goose.errors import Honk
+
+
+class Task[**P, R: Result]:
+    def __init__(
+        self,
+        generator: Callable[P, Awaitable[R]],
+        /,
+        *,
+        retries: int = 0,
+        adapter_model: GeminiModel = GeminiModel.FLASH,
+    ) -> None:
+        self._generator = generator
+        self._retries = retries
+        self._adapter_model = adapter_model
+        self._adapter_model = adapter_model
+
+    @property
+    def result_type(self) -> type[R]:
+        result_type = self._generator.__annotations__.get("return")
+        if result_type is None:
+            raise Honk(f"Task {self.name} has no return type annotation")
+        return result_type
+
+    @property
+    def name(self) -> str:
+        return self._generator.__name__
+
+    async def generate(
+        self, state: NodeState[R], *args: P.args, **kwargs: P.kwargs
+    ) -> R:
+        state_hash = self.__hash_task_call(*args, **kwargs)
+        if state_hash != state.last_hash:
+            result = await self._generator(*args, **kwargs)
+            state.add_result(result=result, new_hash=state_hash, overwrite=True)
+            return result
+        else:
+            return state.result
+
+    async def jam(
+        self,
+        *,
+        user_message: UserMessage,
+        context: SystemMessage | None = None,
+        index: int = 0,
+    ) -> R:
+        flow_run = self.__get_current_flow_run()
+        node_state = flow_run.get(task=self, index=index)
+
+        if context is not None:
+            node_state.set_context(context=context)
+        node_state.add_user_message(message=user_message)
+
+        result = await self.__adapt(
+            conversation=node_state.conversation, agent=flow_run.agent
+        )
+        node_state.add_result(result=result)
+        flow_run.add_node_state(node_state)
+
+        return result
+
+    async def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R:
+        flow_run = self.__get_current_flow_run()
+        node_state = flow_run.get_next(task=self)
+        result = await self.generate(node_state, *args, **kwargs)
+        flow_run.add_node_state(node_state)
+        return result
+
+    async def __adapt(self, *, conversation: Conversation[R], agent: Agent) -> R:
+        messages: list[UserMessage | AssistantMessage] = []
+        for message_index in range(len(conversation.user_messages)):
+            user_message = conversation.user_messages[message_index]
+            result = conversation.result_messages[message_index]
+
+            if isinstance(result, TextResult):
+                assistant_text = result.text
+            else:
+                assistant_text = result.model_dump_json()
+            assistant_message = AssistantMessage(text=assistant_text)
+            messages.append(assistant_message)
+            messages.append(user_message)
+
+        return await agent(
+            messages=messages,
+            model=self._adapter_model,
+            task_name=f"adapt--{self.name}",
+            system=conversation.context,
+            response_model=self.result_type,
+        )
+
+    def __hash_task_call(self, *args: P.args, **kwargs: P.kwargs) -> int:
+        try:
+            to_hash = str(
+                tuple(args)
+                + tuple(kwargs.values())
+                + (self._generator.__code__, self._adapter_model)
+            )
+            return hash(to_hash)
+        except TypeError:
+            raise Honk(f"Unhashable argument to task {self.name}: {args} {kwargs}")
+
+    def __get_current_flow_run(self) -> FlowRun:
+        run = get_current_flow_run()
+        if run is None:
+            raise Honk("No current flow run")
+        return run
+
+
+@overload
+def task[**P, R: Result](generator: Callable[P, Awaitable[R]], /) -> Task[P, R]: ...
+@overload
+def task[**P, R: Result](
+    *, retries: int = 0, adapter_model: GeminiModel = GeminiModel.FLASH
+) -> Callable[[Callable[P, Awaitable[R]]], Task[P, R]]: ...
+def task[**P, R: Result](
+    generator: Callable[P, Awaitable[R]] | None = None,
+    /,
+    *,
+    retries: int = 0,
+    adapter_model: GeminiModel = GeminiModel.FLASH,
+) -> Task[P, R] | Callable[[Callable[P, Awaitable[R]]], Task[P, R]]:
+    if generator is None:

+        def decorator(fn: Callable[P, Awaitable[R]]) -> Task[P, R]:
+            return Task(fn, retries=retries, adapter_model=adapter_model)
+
+        return decorator
+
+    return Task(generator, retries=retries, adapter_model=adapter_model)
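A usage sketch (not part of the published diff): the two `@overload` declarations above give `task` a bare form and a configured form. The function names and bodies below are hypothetical, and they assume `TextResult` exposes a `text` field, consistent with how 0.5.1's agent validated `{"text": ...}` into it. Note that actually invoking a task requires an active flow run; `Task.__call__` raises `Honk("No current flow run")` otherwise.

# Sketch only: imports use the internal paths shown in this diff; the public
# re-export locations in 0.7.0 may differ.
from goose._internal.agent import GeminiModel
from goose._internal.result import TextResult
from goose._internal.task import task


@task  # bare form, first overload
async def summarize(text: str) -> TextResult:
    # assumption: TextResult is a pydantic model with a `text` field
    return TextResult(text=text[:100])


@task(retries=1, adapter_model=GeminiModel.PRO)  # configured form, second overload
async def expand(text: str) -> TextResult:
    return TextResult(text=text * 2)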
/goose/{result.py → _internal/result.py}
File without changes

goose/_internal/types/agent.py
ADDED
@@ -0,0 +1,92 @@
+from enum import StrEnum
+from typing import Literal, NotRequired, TypedDict
+
+from pydantic import BaseModel
+
+
+class GeminiModel(StrEnum):
+    PRO = "vertex_ai/gemini-1.5-pro"
+    FLASH = "vertex_ai/gemini-1.5-flash"
+    FLASH_8B = "vertex_ai/gemini-1.5-flash-8b"
+
+
+class UserMediaContentType(StrEnum):
+    # images
+    JPEG = "image/jpeg"
+    PNG = "image/png"
+    WEBP = "image/webp"
+
+    # audio
+    MP3 = "audio/mp3"
+    WAV = "audio/wav"
+
+    # files
+    PDF = "application/pdf"
+
+
+class LLMTextMessagePart(TypedDict):
+    type: Literal["text"]
+    text: str
+
+
+class LLMMediaMessagePart(TypedDict):
+    type: Literal["image_url"]
+    image_url: str
+
+
+class CacheControl(TypedDict):
+    type: Literal["ephemeral"]
+
+
+class LLMMessage(TypedDict):
+    role: Literal["user", "assistant", "system"]
+    content: list[LLMTextMessagePart | LLMMediaMessagePart]
+    cache_control: NotRequired[CacheControl]
+
+
+class TextMessagePart(BaseModel):
+    text: str
+
+    def render(self) -> LLMTextMessagePart:
+        return {"type": "text", "text": self.text}
+
+
+class MediaMessagePart(BaseModel):
+    content_type: UserMediaContentType
+    content: str
+
+    def render(self) -> LLMMediaMessagePart:
+        return {
+            "type": "image_url",
+            "image_url": f"data:{self.content_type};base64,{self.content}",
+        }
+
+
+class UserMessage(BaseModel):
+    parts: list[TextMessagePart | MediaMessagePart]
+
+    def render(self) -> LLMMessage:
+        content: LLMMessage = {
+            "role": "user",
+            "content": [part.render() for part in self.parts],
+        }
+        if any(isinstance(part, MediaMessagePart) for part in self.parts):
+            content["cache_control"] = {"type": "ephemeral"}
+        return content
+
+
+class AssistantMessage(BaseModel):
+    text: str
+
+    def render(self) -> LLMMessage:
+        return {"role": "assistant", "content": [{"type": "text", "text": self.text}]}
+
+
+class SystemMessage(BaseModel):
+    parts: list[TextMessagePart | MediaMessagePart]
+
+    def render(self) -> LLMMessage:
+        return {
+            "role": "system",
+            "content": [part.render() for part in self.parts],
+        }
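A rendering sketch (not part of the published diff), following the code above: a `UserMessage` that mixes text and media parts renders to a LiteLLM-style dict and picks up `cache_control` because a media part is present. The base64 payload is a placeholder.

from goose._internal.types.agent import (
    MediaMessagePart,
    TextMessagePart,
    UserMediaContentType,
    UserMessage,
)

message = UserMessage(
    parts=[
        TextMessagePart(text="Describe this image."),
        MediaMessagePart(content_type=UserMediaContentType.PNG, content="<base64>"),
    ]
)
rendered = message.render()
# rendered == {
#     "role": "user",
#     "content": [
#         {"type": "text", "text": "Describe this image."},
#         {"type": "image_url", "image_url": "data:image/png;base64,<base64>"},
#     ],
#     "cache_control": {"type": "ephemeral"},  # added because a media part is present
# }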
goose/agent.py
CHANGED
@@ -1,283 +1,28 @@
(old lines 1-28 are truncated in the extracted diff; they held the module's imports and the GeminiModel and UserMediaContentType enums, which moved verbatim to goose/_internal/types/agent.py)
-    PDF = "application/pdf"
-
-
-class LLMTextMessagePart(TypedDict):
-    type: Literal["text"]
-    text: str
-
-
-class LLMMediaMessagePart(TypedDict):
-    type: Literal["image_url"]
-    image_url: str
-
-
-class CacheControl(TypedDict):
-    type: Literal["ephemeral"]
-
-
-class LLMMessage(TypedDict):
-    role: Literal["user", "assistant", "system"]
-    content: list[LLMTextMessagePart | LLMMediaMessagePart]
-    cache_control: NotRequired[CacheControl]
-
-
-class TextMessagePart(BaseModel):
-    text: str
-
-    def render(self) -> LLMTextMessagePart:
-        return {"type": "text", "text": self.text}
-
-
-class MediaMessagePart(BaseModel):
-    content_type: UserMediaContentType
-    content: str
-
-    def render(self) -> LLMMediaMessagePart:
-        return {
-            "type": "image_url",
-            "image_url": f"data:{self.content_type};base64,{self.content}",
-        }
-
-
-class UserMessage(BaseModel):
-    parts: list[TextMessagePart | MediaMessagePart]
-
-    def render(self) -> LLMMessage:
-        content: LLMMessage = {
-            "role": "user",
-            "content": [part.render() for part in self.parts],
-        }
-        if any(isinstance(part, MediaMessagePart) for part in self.parts):
-            content["cache_control"] = {"type": "ephemeral"}
-        return content
-
-
-class AssistantMessage(BaseModel):
-    text: str
-
-    def render(self) -> LLMMessage:
-        return {"role": "assistant", "content": [{"type": "text", "text": self.text}]}
-
-
-class SystemMessage(BaseModel):
-    parts: list[TextMessagePart | MediaMessagePart]
-
-    def render(self) -> LLMMessage:
-        return {
-            "role": "system",
-            "content": [part.render() for part in self.parts],
-        }
-
-
-class AgentResponseDump(TypedDict):
-    run_id: str
-    flow_name: str
-    task_name: str
-    model: str
-    system_message: str
-    input_messages: list[str]
-    output_message: str
-    input_cost: float
-    output_cost: float
-    total_cost: float
-    input_tokens: int
-    output_tokens: int
-    start_time: datetime
-    end_time: datetime
-    duration_ms: int
-
-
-class AgentResponse[R: BaseModel | str](BaseModel):
-    INPUT_CENTS_PER_MILLION_TOKENS: ClassVar[dict[GeminiModel, float]] = {
-        GeminiModel.FLASH_8B: 30,
-        GeminiModel.FLASH: 15,
-        GeminiModel.PRO: 500,
-    }
-    OUTPUT_CENTS_PER_MILLION_TOKENS: ClassVar[dict[GeminiModel, float]] = {
-        GeminiModel.FLASH_8B: 30,
-        GeminiModel.FLASH: 15,
-        GeminiModel.PRO: 500,
-    }
-
-    response: R
-    run_id: str
-    flow_name: str
-    task_name: str
-    model: GeminiModel
-    system: SystemMessage | None = None
-    input_messages: list[UserMessage | AssistantMessage]
-    input_tokens: int
-    output_tokens: int
-    start_time: datetime
-    end_time: datetime
-
-    @computed_field
-    @property
-    def duration_ms(self) -> int:
-        return int((self.end_time - self.start_time).total_seconds() * 1000)
-
-    @computed_field
-    @property
-    def input_cost(self) -> float:
-        return (
-            self.INPUT_CENTS_PER_MILLION_TOKENS[self.model]
-            * self.input_tokens
-            / 1_000_000
-        )
-
-    @computed_field
-    @property
-    def output_cost(self) -> float:
-        return (
-            self.OUTPUT_CENTS_PER_MILLION_TOKENS[self.model]
-            * self.output_tokens
-            / 1_000_000
-        )
-
-    @computed_field
-    @property
-    def total_cost(self) -> float:
-        return self.input_cost + self.output_cost
-
-    def minimized_dump(self) -> AgentResponseDump:
-        if self.system is None:
-            minimized_system_message = ""
-        else:
-            minimized_system_message = self.system.render()
-            for part in minimized_system_message["content"]:
-                if part["type"] == "image_url":
-                    part["image_url"] = "__MEDIA__"
-            minimized_system_message = json.dumps(minimized_system_message)
-
-        minimized_input_messages = [message.render() for message in self.input_messages]
-        for message in minimized_input_messages:
-            for part in message["content"]:
-                if part["type"] == "image_url":
-                    part["image_url"] = "__MEDIA__"
-        minimized_input_messages = [
-            json.dumps(message) for message in minimized_input_messages
-        ]
-
-        output_message = (
-            self.response.model_dump_json()
-            if isinstance(self.response, BaseModel)
-            else self.response
-        )
-
-        return {
-            "run_id": self.run_id,
-            "flow_name": self.flow_name,
-            "task_name": self.task_name,
-            "model": self.model.value,
-            "system_message": minimized_system_message,
-            "input_messages": minimized_input_messages,
-            "output_message": output_message,
-            "input_tokens": self.input_tokens,
-            "output_tokens": self.output_tokens,
-            "input_cost": self.input_cost,
-            "output_cost": self.output_cost,
-            "total_cost": self.total_cost,
-            "start_time": self.start_time,
-            "end_time": self.end_time,
-            "duration_ms": self.duration_ms,
-        }
-
-
-class IAgentLogger(Protocol):
-    async def __call__(self, *, response: AgentResponse[Any]) -> None: ...
-
-
-class Agent:
-    def __init__(
-        self,
-        *,
-        flow_name: str,
-        run_id: str,
-        logger: IAgentLogger | None = None,
-    ) -> None:
-        self.flow_name = flow_name
-        self.run_id = run_id
-        self.logger = logger
-
-    async def __call__[R: Result](
-        self,
-        *,
-        messages: list[UserMessage | AssistantMessage],
-        model: GeminiModel,
-        task_name: str,
-        response_model: type[R] = TextResult,
-        system: SystemMessage | None = None,
-    ) -> R:
-        start_time = datetime.now()
-        rendered_messages = [message.render() for message in messages]
-        if system is not None:
-            rendered_messages.insert(0, system.render())
-
-        if response_model is TextResult:
-            response = await acompletion(model=model.value, messages=rendered_messages)
-            parsed_response = response_model.model_validate(
-                {"text": response.choices[0].message.content}
-            )
-        else:
-            response = await acompletion(
-                model=model.value,
-                messages=rendered_messages,
-                response_format={
-                    "type": "json_object",
-                    "response_schema": response_model.model_json_schema(),
-                    "enforce_validation": True,
-                },
-            )
-            parsed_response = response_model.model_validate_json(
-                response.choices[0].message.content
-            )
-
-        end_time = datetime.now()
-        agent_response = AgentResponse(
-            response=parsed_response,
-            run_id=self.run_id,
-            flow_name=self.flow_name,
-            task_name=task_name,
-            model=model,
-            system=system,
-            input_messages=messages,
-            input_tokens=response.usage.prompt_tokens,
-            output_tokens=response.usage.completion_tokens,
-            start_time=start_time,
-            end_time=end_time,
-        )
-
-        if self.logger is not None:
-            await self.logger(response=agent_response)
-        else:
-            logging.info(agent_response.model_dump())
-
-        return parsed_response
+from goose._internal.agent import AgentResponse, IAgentLogger
+from goose._internal.types.agent import (
+    AssistantMessage,
+    GeminiModel,
+    LLMMediaMessagePart,
+    LLMMessage,
+    LLMTextMessagePart,
+    MediaMessagePart,
+    SystemMessage,
+    TextMessagePart,
+    UserMediaContentType,
+    UserMessage,
+)
+
+__all__ = [
+    "AgentResponse",
+    "IAgentLogger",
+    "AssistantMessage",
+    "GeminiModel",
+    "LLMMediaMessagePart",
+    "LLMMessage",
+    "LLMTextMessagePart",
+    "MediaMessagePart",
+    "SystemMessage",
+    "TextMessagePart",
+    "UserMediaContentType",
+    "UserMessage",
+]
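Net effect: `goose/agent.py` shrinks from a 283-line implementation module to a 28-line re-export facade. The message types and `GeminiModel` keep their old import path, but `Agent` itself is no longer exported from `goose.agent`. A migration sketch (not part of the published diff), based only on the files shown here:

# These imports resolve the same way in 0.5.1 and 0.7.0:
from goose.agent import AgentResponse, GeminiModel, SystemMessage, UserMessage

# 0.5.1: from goose.agent import Agent
# 0.7.0: Agent lives in the internal module (assumption: no other public
# alias is added elsewhere in 0.7.0, judging only from this diff):
from goose._internal.agent import Agent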