goose-py 0.5.1__py3-none-any.whl → 0.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,136 @@
1
+ from typing import Awaitable, Callable, overload
2
+
3
+ from goose._internal.agent import Agent, GeminiModel, SystemMessage, UserMessage
4
+ from goose._internal.conversation import Conversation
5
+ from goose._internal.result import Result, TextResult
6
+ from goose._internal.state import FlowRun, NodeState, get_current_flow_run
7
+ from goose._internal.types.agent import AssistantMessage
8
+ from goose.errors import Honk
9
+
10
+
11
+ class Task[**P, R: Result]:
12
+ def __init__(
13
+ self,
14
+ generator: Callable[P, Awaitable[R]],
15
+ /,
16
+ *,
17
+ retries: int = 0,
18
+ adapter_model: GeminiModel = GeminiModel.FLASH,
19
+ ) -> None:
20
+ self._generator = generator
21
+ self._retries = retries
22
+ self._adapter_model = adapter_model
23
+ self._adapter_model = adapter_model
24
+
25
+ @property
26
+ def result_type(self) -> type[R]:
27
+ result_type = self._generator.__annotations__.get("return")
28
+ if result_type is None:
29
+ raise Honk(f"Task {self.name} has no return type annotation")
30
+ return result_type
31
+
32
+ @property
33
+ def name(self) -> str:
34
+ return self._generator.__name__
35
+
36
+ async def generate(
37
+ self, state: NodeState[R], *args: P.args, **kwargs: P.kwargs
38
+ ) -> R:
39
+ state_hash = self.__hash_task_call(*args, **kwargs)
40
+ if state_hash != state.last_hash:
41
+ result = await self._generator(*args, **kwargs)
42
+ state.add_result(result=result, new_hash=state_hash, overwrite=True)
43
+ return result
44
+ else:
45
+ return state.result
46
+
47
+ async def jam(
48
+ self,
49
+ *,
50
+ user_message: UserMessage,
51
+ context: SystemMessage | None = None,
52
+ index: int = 0,
53
+ ) -> R:
54
+ flow_run = self.__get_current_flow_run()
55
+ node_state = flow_run.get(task=self, index=index)
56
+
57
+ if context is not None:
58
+ node_state.set_context(context=context)
59
+ node_state.add_user_message(message=user_message)
60
+
61
+ result = await self.__adapt(
62
+ conversation=node_state.conversation, agent=flow_run.agent
63
+ )
64
+ node_state.add_result(result=result)
65
+ flow_run.add_node_state(node_state)
66
+
67
+ return result
68
+
69
+ async def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R:
70
+ flow_run = self.__get_current_flow_run()
71
+ node_state = flow_run.get_next(task=self)
72
+ result = await self.generate(node_state, *args, **kwargs)
73
+ flow_run.add_node_state(node_state)
74
+ return result
75
+
76
+ async def __adapt(self, *, conversation: Conversation[R], agent: Agent) -> R:
77
+ messages: list[UserMessage | AssistantMessage] = []
78
+ for message_index in range(len(conversation.user_messages)):
79
+ user_message = conversation.user_messages[message_index]
80
+ result = conversation.result_messages[message_index]
81
+
82
+ if isinstance(result, TextResult):
83
+ assistant_text = result.text
84
+ else:
85
+ assistant_text = result.model_dump_json()
86
+ assistant_message = AssistantMessage(text=assistant_text)
87
+ messages.append(assistant_message)
88
+ messages.append(user_message)
89
+
90
+ return await agent(
91
+ messages=messages,
92
+ model=self._adapter_model,
93
+ task_name=f"adapt--{self.name}",
94
+ system=conversation.context,
95
+ response_model=self.result_type,
96
+ )
97
+
98
+ def __hash_task_call(self, *args: P.args, **kwargs: P.kwargs) -> int:
99
+ try:
100
+ to_hash = str(
101
+ tuple(args)
102
+ + tuple(kwargs.values())
103
+ + (self._generator.__code__, self._adapter_model)
104
+ )
105
+ return hash(to_hash)
106
+ except TypeError:
107
+ raise Honk(f"Unhashable argument to task {self.name}: {args} {kwargs}")
108
+
109
+ def __get_current_flow_run(self) -> FlowRun:
110
+ run = get_current_flow_run()
111
+ if run is None:
112
+ raise Honk("No current flow run")
113
+ return run
114
+
115
+
116
+ @overload
117
+ def task[**P, R: Result](generator: Callable[P, Awaitable[R]], /) -> Task[P, R]: ...
118
+ @overload
119
+ def task[**P, R: Result](
120
+ *, retries: int = 0, adapter_model: GeminiModel = GeminiModel.FLASH
121
+ ) -> Callable[[Callable[P, Awaitable[R]]], Task[P, R]]: ...
122
+ def task[**P, R: Result](
123
+ generator: Callable[P, Awaitable[R]] | None = None,
124
+ /,
125
+ *,
126
+ retries: int = 0,
127
+ adapter_model: GeminiModel = GeminiModel.FLASH,
128
+ ) -> Task[P, R] | Callable[[Callable[P, Awaitable[R]]], Task[P, R]]:
129
+ if generator is None:
130
+
131
+ def decorator(fn: Callable[P, Awaitable[R]]) -> Task[P, R]:
132
+ return Task(fn, retries=retries, adapter_model=adapter_model)
133
+
134
+ return decorator
135
+
136
+ return Task(generator, retries=retries, adapter_model=adapter_model)
File without changes
@@ -0,0 +1,92 @@
1
+ from enum import StrEnum
2
+ from typing import Literal, NotRequired, TypedDict
3
+
4
+ from pydantic import BaseModel
5
+
6
+
7
+ class GeminiModel(StrEnum):
8
+ PRO = "vertex_ai/gemini-1.5-pro"
9
+ FLASH = "vertex_ai/gemini-1.5-flash"
10
+ FLASH_8B = "vertex_ai/gemini-1.5-flash-8b"
11
+
12
+
13
+ class UserMediaContentType(StrEnum):
14
+ # images
15
+ JPEG = "image/jpeg"
16
+ PNG = "image/png"
17
+ WEBP = "image/webp"
18
+
19
+ # audio
20
+ MP3 = "audio/mp3"
21
+ WAV = "audio/wav"
22
+
23
+ # files
24
+ PDF = "application/pdf"
25
+
26
+
27
+ class LLMTextMessagePart(TypedDict):
28
+ type: Literal["text"]
29
+ text: str
30
+
31
+
32
+ class LLMMediaMessagePart(TypedDict):
33
+ type: Literal["image_url"]
34
+ image_url: str
35
+
36
+
37
+ class CacheControl(TypedDict):
38
+ type: Literal["ephemeral"]
39
+
40
+
41
class LLMMessage(TypedDict):
    """Rendered message: a role, its content parts and optional cache control."""

    role: Literal["user", "assistant", "system"]
    content: list[LLMTextMessagePart | LLMMediaMessagePart]
    # Omitted entirely unless a renderer chooses to set it.
    cache_control: NotRequired[CacheControl]
45
+
46
+
47
class TextMessagePart(BaseModel):
    """A plain-text piece of a message."""

    text: str

    def render(self) -> LLMTextMessagePart:
        """Return this part as a ``"text"``-typed dictionary."""
        return LLMTextMessagePart(type="text", text=self.text)
52
+
53
+
54
class MediaMessagePart(BaseModel):
    """A media piece of a message, held as a content type plus a content string."""

    content_type: UserMediaContentType
    content: str

    def render(self) -> LLMMediaMessagePart:
        """Return this part as an ``"image_url"`` dictionary carrying a data URL."""
        # The content string is placed after the ";base64," marker as-is.
        data_url = f"data:{self.content_type};base64,{self.content}"
        return LLMMediaMessagePart(type="image_url", image_url=data_url)
63
+
64
+
65
class UserMessage(BaseModel):
    """A user-role message made of text and/or media parts."""

    parts: list[TextMessagePart | MediaMessagePart]

    def render(self) -> LLMMessage:
        """Return the message as a ``"user"`` dictionary.

        An ephemeral ``cache_control`` entry is added when at least one part
        is a ``MediaMessagePart``.
        """
        rendered: LLMMessage = {
            "role": "user",
            "content": [part.render() for part in self.parts],
        }
        has_media = any(isinstance(part, MediaMessagePart) for part in self.parts)
        if has_media:
            rendered["cache_control"] = {"type": "ephemeral"}
        return rendered
76
+
77
+
78
class AssistantMessage(BaseModel):
    """An assistant-role message consisting of a single text."""

    text: str

    def render(self) -> LLMMessage:
        """Return the message as an ``"assistant"`` dictionary with one text part."""
        text_part: LLMTextMessagePart = {"type": "text", "text": self.text}
        return {"role": "assistant", "content": [text_part]}
83
+
84
+
85
class SystemMessage(BaseModel):
    """A system-role message made of text and/or media parts."""

    parts: list[TextMessagePart | MediaMessagePart]

    def render(self) -> LLMMessage:
        """Return the message as a ``"system"`` dictionary with every part rendered."""
        rendered_parts = [part.render() for part in self.parts]
        return {"role": "system", "content": rendered_parts}
goose/agent.py CHANGED
@@ -1,283 +1,28 @@
1
- import json
2
- import logging
3
- from datetime import datetime
4
- from enum import StrEnum
5
- from typing import Any, ClassVar, Literal, NotRequired, Protocol, TypedDict
6
-
7
- from litellm import acompletion
8
- from pydantic import BaseModel, computed_field
9
- from goose.result import Result, TextResult
10
-
11
-
12
- class GeminiModel(StrEnum):
13
- PRO = "vertex_ai/gemini-1.5-pro"
14
- FLASH = "vertex_ai/gemini-1.5-flash"
15
- FLASH_8B = "vertex_ai/gemini-1.5-flash-8b"
16
-
17
-
18
- class UserMediaContentType(StrEnum):
19
- # images
20
- JPEG = "image/jpeg"
21
- PNG = "image/png"
22
- WEBP = "image/webp"
23
-
24
- # audio
25
- MP3 = "audio/mp3"
26
- WAV = "audio/wav"
27
-
28
- # files
29
- PDF = "application/pdf"
30
-
31
-
32
- class LLMTextMessagePart(TypedDict):
33
- type: Literal["text"]
34
- text: str
35
-
36
-
37
- class LLMMediaMessagePart(TypedDict):
38
- type: Literal["image_url"]
39
- image_url: str
40
-
41
-
42
- class CacheControl(TypedDict):
43
- type: Literal["ephemeral"]
44
-
45
-
46
- class LLMMessage(TypedDict):
47
- role: Literal["user", "assistant", "system"]
48
- content: list[LLMTextMessagePart | LLMMediaMessagePart]
49
- cache_control: NotRequired[CacheControl]
50
-
51
-
52
- class TextMessagePart(BaseModel):
53
- text: str
54
-
55
- def render(self) -> LLMTextMessagePart:
56
- return {"type": "text", "text": self.text}
57
-
58
-
59
- class MediaMessagePart(BaseModel):
60
- content_type: UserMediaContentType
61
- content: str
62
-
63
- def render(self) -> LLMMediaMessagePart:
64
- return {
65
- "type": "image_url",
66
- "image_url": f"data:{self.content_type};base64,{self.content}",
67
- }
68
-
69
-
70
- class UserMessage(BaseModel):
71
- parts: list[TextMessagePart | MediaMessagePart]
72
-
73
- def render(self) -> LLMMessage:
74
- content: LLMMessage = {
75
- "role": "user",
76
- "content": [part.render() for part in self.parts],
77
- }
78
- if any(isinstance(part, MediaMessagePart) for part in self.parts):
79
- content["cache_control"] = {"type": "ephemeral"}
80
- return content
81
-
82
-
83
- class AssistantMessage(BaseModel):
84
- text: str
85
-
86
- def render(self) -> LLMMessage:
87
- return {"role": "assistant", "content": [{"type": "text", "text": self.text}]}
88
-
89
-
90
- class SystemMessage(BaseModel):
91
- parts: list[TextMessagePart | MediaMessagePart]
92
-
93
- def render(self) -> LLMMessage:
94
- return {
95
- "role": "system",
96
- "content": [part.render() for part in self.parts],
97
- }
98
-
99
-
100
- class AgentResponseDump(TypedDict):
101
- run_id: str
102
- flow_name: str
103
- task_name: str
104
- model: str
105
- system_message: str
106
- input_messages: list[str]
107
- output_message: str
108
- input_cost: float
109
- output_cost: float
110
- total_cost: float
111
- input_tokens: int
112
- output_tokens: int
113
- start_time: datetime
114
- end_time: datetime
115
- duration_ms: int
116
-
117
-
118
- class AgentResponse[R: BaseModel | str](BaseModel):
119
- INPUT_CENTS_PER_MILLION_TOKENS: ClassVar[dict[GeminiModel, float]] = {
120
- GeminiModel.FLASH_8B: 30,
121
- GeminiModel.FLASH: 15,
122
- GeminiModel.PRO: 500,
123
- }
124
- OUTPUT_CENTS_PER_MILLION_TOKENS: ClassVar[dict[GeminiModel, float]] = {
125
- GeminiModel.FLASH_8B: 30,
126
- GeminiModel.FLASH: 15,
127
- GeminiModel.PRO: 500,
128
- }
129
-
130
- response: R
131
- run_id: str
132
- flow_name: str
133
- task_name: str
134
- model: GeminiModel
135
- system: SystemMessage | None = None
136
- input_messages: list[UserMessage | AssistantMessage]
137
- input_tokens: int
138
- output_tokens: int
139
- start_time: datetime
140
- end_time: datetime
141
-
142
- @computed_field
143
- @property
144
- def duration_ms(self) -> int:
145
- return int((self.end_time - self.start_time).total_seconds() * 1000)
146
-
147
- @computed_field
148
- @property
149
- def input_cost(self) -> float:
150
- return (
151
- self.INPUT_CENTS_PER_MILLION_TOKENS[self.model]
152
- * self.input_tokens
153
- / 1_000_000
154
- )
155
-
156
- @computed_field
157
- @property
158
- def output_cost(self) -> float:
159
- return (
160
- self.OUTPUT_CENTS_PER_MILLION_TOKENS[self.model]
161
- * self.output_tokens
162
- / 1_000_000
163
- )
164
-
165
- @computed_field
166
- @property
167
- def total_cost(self) -> float:
168
- return self.input_cost + self.output_cost
169
-
170
- def minimized_dump(self) -> AgentResponseDump:
171
- if self.system is None:
172
- minimized_system_message = ""
173
- else:
174
- minimized_system_message = self.system.render()
175
- for part in minimized_system_message["content"]:
176
- if part["type"] == "image_url":
177
- part["image_url"] = "__MEDIA__"
178
- minimized_system_message = json.dumps(minimized_system_message)
179
-
180
- minimized_input_messages = [message.render() for message in self.input_messages]
181
- for message in minimized_input_messages:
182
- for part in message["content"]:
183
- if part["type"] == "image_url":
184
- part["image_url"] = "__MEDIA__"
185
- minimized_input_messages = [
186
- json.dumps(message) for message in minimized_input_messages
187
- ]
188
-
189
- output_message = (
190
- self.response.model_dump_json()
191
- if isinstance(self.response, BaseModel)
192
- else self.response
193
- )
194
-
195
- return {
196
- "run_id": self.run_id,
197
- "flow_name": self.flow_name,
198
- "task_name": self.task_name,
199
- "model": self.model.value,
200
- "system_message": minimized_system_message,
201
- "input_messages": minimized_input_messages,
202
- "output_message": output_message,
203
- "input_tokens": self.input_tokens,
204
- "output_tokens": self.output_tokens,
205
- "input_cost": self.input_cost,
206
- "output_cost": self.output_cost,
207
- "total_cost": self.total_cost,
208
- "start_time": self.start_time,
209
- "end_time": self.end_time,
210
- "duration_ms": self.duration_ms,
211
- }
212
-
213
-
214
- class IAgentLogger(Protocol):
215
- async def __call__(self, *, response: AgentResponse[Any]) -> None: ...
216
-
217
-
218
- class Agent:
219
- def __init__(
220
- self,
221
- *,
222
- flow_name: str,
223
- run_id: str,
224
- logger: IAgentLogger | None = None,
225
- ) -> None:
226
- self.flow_name = flow_name
227
- self.run_id = run_id
228
- self.logger = logger
229
-
230
- async def __call__[R: Result](
231
- self,
232
- *,
233
- messages: list[UserMessage | AssistantMessage],
234
- model: GeminiModel,
235
- task_name: str,
236
- response_model: type[R] = TextResult,
237
- system: SystemMessage | None = None,
238
- ) -> R:
239
- start_time = datetime.now()
240
- rendered_messages = [message.render() for message in messages]
241
- if system is not None:
242
- rendered_messages.insert(0, system.render())
243
-
244
- if response_model is TextResult:
245
- response = await acompletion(model=model.value, messages=rendered_messages)
246
- parsed_response = response_model.model_validate(
247
- {"text": response.choices[0].message.content}
248
- )
249
- else:
250
- response = await acompletion(
251
- model=model.value,
252
- messages=rendered_messages,
253
- response_format={
254
- "type": "json_object",
255
- "response_schema": response_model.model_json_schema(),
256
- "enforce_validation": True,
257
- },
258
- )
259
- parsed_response = response_model.model_validate_json(
260
- response.choices[0].message.content
261
- )
262
-
263
- end_time = datetime.now()
264
- agent_response = AgentResponse(
265
- response=parsed_response,
266
- run_id=self.run_id,
267
- flow_name=self.flow_name,
268
- task_name=task_name,
269
- model=model,
270
- system=system,
271
- input_messages=messages,
272
- input_tokens=response.usage.prompt_tokens,
273
- output_tokens=response.usage.completion_tokens,
274
- start_time=start_time,
275
- end_time=end_time,
276
- )
277
-
278
- if self.logger is not None:
279
- await self.logger(response=agent_response)
280
- else:
281
- logging.info(agent_response.model_dump())
282
-
283
- return parsed_response
1
+ from goose._internal.agent import AgentResponse, IAgentLogger
2
+ from goose._internal.types.agent import (
3
+ AssistantMessage,
4
+ GeminiModel,
5
+ LLMMediaMessagePart,
6
+ LLMMessage,
7
+ LLMTextMessagePart,
8
+ MediaMessagePart,
9
+ SystemMessage,
10
+ TextMessagePart,
11
+ UserMediaContentType,
12
+ UserMessage,
13
+ )
14
+
15
+ __all__ = [
16
+ "AgentResponse",
17
+ "IAgentLogger",
18
+ "AssistantMessage",
19
+ "GeminiModel",
20
+ "LLMMediaMessagePart",
21
+ "LLMMessage",
22
+ "LLMTextMessagePart",
23
+ "MediaMessagePart",
24
+ "SystemMessage",
25
+ "TextMessagePart",
26
+ "UserMediaContentType",
27
+ "UserMessage",
28
+ ]