goose-py 0.9.15__py3-none-any.whl → 0.10.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- goose/_internal/agent.py +184 -119
- goose/_internal/conversation.py +23 -6
- goose/_internal/result.py +12 -1
- goose/_internal/state.py +25 -19
- goose/_internal/task.py +37 -35
- goose/_internal/types/telemetry.py +113 -0
- {goose_py-0.9.15.dist-info → goose_py-0.10.0.dist-info}/METADATA +1 -1
- {goose_py-0.9.15.dist-info → goose_py-0.10.0.dist-info}/RECORD +9 -8
- {goose_py-0.9.15.dist-info → goose_py-0.10.0.dist-info}/WHEEL +0 -0
goose/_internal/agent.py
CHANGED
@@ -1,119 +1,15 @@
-import json
 import logging
 from datetime import datetime
-from typing import Any, …
+from typing import Any, Literal, Protocol, overload
 
 from litellm import acompletion
-from pydantic import …
-
-from .…
-from .…
-
-
-class AgentResponseDump(TypedDict):
-    run_id: str
-    flow_name: str
-    task_name: str
-    model: str
-    system_message: str
-    input_messages: list[str]
-    output_message: str
-    input_cost: float
-    output_cost: float
-    total_cost: float
-    input_tokens: int
-    output_tokens: int
-    start_time: datetime
-    end_time: datetime
-    duration_ms: int
-
-
-class AgentResponse[R: BaseModel | str](BaseModel):
-    INPUT_CENTS_PER_MILLION_TOKENS: ClassVar[dict[AIModel, float]] = {
-        AIModel.VERTEX_FLASH_8B: 30,
-        AIModel.VERTEX_FLASH: 15,
-        AIModel.VERTEX_PRO: 500,
-        AIModel.GEMINI_FLASH_8B: 30,
-        AIModel.GEMINI_FLASH: 15,
-        AIModel.GEMINI_PRO: 500,
-    }
-    OUTPUT_CENTS_PER_MILLION_TOKENS: ClassVar[dict[AIModel, float]] = {
-        AIModel.VERTEX_FLASH_8B: 30,
-        AIModel.VERTEX_FLASH: 15,
-        AIModel.VERTEX_PRO: 500,
-        AIModel.GEMINI_FLASH_8B: 30,
-        AIModel.GEMINI_FLASH: 15,
-        AIModel.GEMINI_PRO: 500,
-    }
-
-    response: R
-    run_id: str
-    flow_name: str
-    task_name: str
-    model: AIModel
-    system: SystemMessage | None = None
-    input_messages: list[UserMessage | AssistantMessage]
-    input_tokens: int
-    output_tokens: int
-    start_time: datetime
-    end_time: datetime
-
-    @computed_field
-    @property
-    def duration_ms(self) -> int:
-        return int((self.end_time - self.start_time).total_seconds() * 1000)
-
-    @computed_field
-    @property
-    def input_cost(self) -> float:
-        return self.INPUT_CENTS_PER_MILLION_TOKENS[self.model] * self.input_tokens / 1_000_000
-
-    @computed_field
-    @property
-    def output_cost(self) -> float:
-        return self.OUTPUT_CENTS_PER_MILLION_TOKENS[self.model] * self.output_tokens / 1_000_000
-
-    @computed_field
-    @property
-    def total_cost(self) -> float:
-        return self.input_cost + self.output_cost
-
-    def minimized_dump(self) -> AgentResponseDump:
-        if self.system is None:
-            minimized_system_message = ""
-        else:
-            minimized_system_message = self.system.render()
-            for part in minimized_system_message["content"]:
-                if part["type"] == "image_url":
-                    part["image_url"] = "__MEDIA__"
-            minimized_system_message = json.dumps(minimized_system_message)
-
-        minimized_input_messages = [message.render() for message in self.input_messages]
-        for message in minimized_input_messages:
-            for part in message["content"]:
-                if part["type"] == "image_url":
-                    part["image_url"] = "__MEDIA__"
-        minimized_input_messages = [json.dumps(message) for message in minimized_input_messages]
-
-        output_message = self.response.model_dump_json() if isinstance(self.response, BaseModel) else self.response
-
-        return {
-            "run_id": self.run_id,
-            "flow_name": self.flow_name,
-            "task_name": self.task_name,
-            "model": self.model.value,
-            "system_message": minimized_system_message,
-            "input_messages": minimized_input_messages,
-            "output_message": output_message,
-            "input_tokens": self.input_tokens,
-            "output_tokens": self.output_tokens,
-            "input_cost": self.input_cost,
-            "output_cost": self.output_cost,
-            "total_cost": self.total_cost,
-            "start_time": self.start_time,
-            "end_time": self.end_time,
-            "duration_ms": self.duration_ms,
-        }
+from pydantic import ValidationError
+
+from goose._internal.types.telemetry import AgentResponse
+from goose.errors import Honk
+
+from .result import FindReplaceResponse, Result, TextResult
+from .types.agent import AIModel, LLMMessage
 
 
 class IAgentLogger(Protocol):
@@ -132,27 +28,26 @@ class Agent:
         self.run_id = run_id
         self.logger = logger
 
-    async def …
+    async def generate[R: Result](
         self,
         *,
-        messages: list[…
+        messages: list[LLMMessage],
         model: AIModel,
         task_name: str,
         response_model: type[R] = TextResult,
-        system: …
+        system: LLMMessage | None = None,
     ) -> R:
         start_time = datetime.now()
-        rendered_messages = [message.render() for message in messages]
         if system is not None:
-            …
+            messages.insert(0, system)
 
         if response_model is TextResult:
-            response = await acompletion(model=model.value, messages=…
+            response = await acompletion(model=model.value, messages=messages)
             parsed_response = response_model.model_validate({"text": response.choices[0].message.content})
         else:
             response = await acompletion(
                 model=model.value,
-                messages=…
+                messages=messages,
                 response_format=response_model,
             )
             parsed_response = response_model.model_validate_json(response.choices[0].message.content)
@@ -178,3 +73,173 @@ class Agent:
             logging.info(agent_response.model_dump())
 
         return parsed_response
+
+    async def ask(
+        self, *, messages: list[LLMMessage], model: AIModel, task_name: str, system: LLMMessage | None = None
+    ) -> str:
+        start_time = datetime.now()
+
+        if system is not None:
+            messages.insert(0, system)
+        response = await acompletion(model=model.value, messages=messages)
+
+        end_time = datetime.now()
+        agent_response = AgentResponse(
+            response=response.choices[0].message.content,
+            run_id=self.run_id,
+            flow_name=self.flow_name,
+            task_name=task_name,
+            model=model,
+            system=system,
+            input_messages=messages,
+            input_tokens=response.usage.prompt_tokens,
+            output_tokens=response.usage.completion_tokens,
+            start_time=start_time,
+            end_time=end_time,
+        )
+
+        if self.logger is not None:
+            await self.logger(response=agent_response)
+        else:
+            logging.info(agent_response.model_dump())
+
+        return response.choices[0].message.content
+
+    async def refine[R: Result](
+        self,
+        *,
+        messages: list[LLMMessage],
+        model: AIModel,
+        task_name: str,
+        response_model: type[R],
+        system: LLMMessage | None = None,
+    ) -> R:
+        start_time = datetime.now()
+
+        if system is not None:
+            messages.insert(0, system)
+
+        find_replace_response = await acompletion(
+            model=model.value, messages=messages, response_format=FindReplaceResponse
+        )
+        parsed_find_replace_response = FindReplaceResponse.model_validate_json(
+            find_replace_response.choices[0].message.content
+        )
+
+        end_time = datetime.now()
+        agent_response = AgentResponse(
+            response=parsed_find_replace_response,
+            run_id=self.run_id,
+            flow_name=self.flow_name,
+            task_name=task_name,
+            model=model,
+            system=system,
+            input_messages=messages,
+            input_tokens=find_replace_response.usage.prompt_tokens,
+            output_tokens=find_replace_response.usage.completion_tokens,
+            start_time=start_time,
+            end_time=end_time,
+        )
+
+        if self.logger is not None:
+            await self.logger(response=agent_response)
+        else:
+            logging.info(agent_response.model_dump())
+
+        refined_response = self.__apply_find_replace(
+            result=self.__find_last_result(messages=messages, response_model=response_model),
+            find_replace_response=parsed_find_replace_response,
+            response_model=response_model,
+        )
+
+        return refined_response
+
+    @overload
+    async def __call__[R: Result](
+        self,
+        *,
+        messages: list[LLMMessage],
+        model: AIModel,
+        task_name: str,
+        mode: Literal["generate"],
+        response_model: type[R],
+        system: LLMMessage | None = None,
+    ) -> R: ...
+
+    @overload
+    async def __call__[R: Result](
+        self,
+        *,
+        messages: list[LLMMessage],
+        model: AIModel,
+        task_name: str,
+        mode: Literal["ask"],
+        response_model: type[R] = TextResult,
+        system: LLMMessage | None = None,
+    ) -> str: ...
+
+    @overload
+    async def __call__[R: Result](
+        self,
+        *,
+        messages: list[LLMMessage],
+        model: AIModel,
+        task_name: str,
+        response_model: type[R],
+        mode: Literal["refine"],
+        system: LLMMessage | None = None,
+    ) -> R: ...
+
+    @overload
+    async def __call__[R: Result](
+        self,
+        *,
+        messages: list[LLMMessage],
+        model: AIModel,
+        task_name: str,
+        response_model: type[R],
+        system: LLMMessage | None = None,
+    ) -> R: ...
+
+    async def __call__[R: Result](
+        self,
+        *,
+        messages: list[LLMMessage],
+        model: AIModel,
+        task_name: str,
+        response_model: type[R] = TextResult,
+        mode: Literal["generate", "ask", "refine"] = "generate",
+        system: LLMMessage | None = None,
+    ) -> R | str:
+        match mode:
+            case "generate":
+                return await self.generate(
+                    messages=messages, model=model, task_name=task_name, response_model=response_model, system=system
+                )
+            case "ask":
+                return await self.ask(messages=messages, model=model, task_name=task_name, system=system)
+            case "refine":
+                return await self.refine(
+                    messages=messages, model=model, task_name=task_name, response_model=response_model, system=system
+                )
+
+    def __apply_find_replace[R: Result](
+        self, *, result: R, find_replace_response: FindReplaceResponse, response_model: type[R]
+    ) -> R:
+        dumped_result = result.model_dump_json()
+        for replacement in find_replace_response.replacements:
+            dumped_result = dumped_result.replace(replacement.find, replacement.replace)
+
+        return response_model.model_validate_json(dumped_result)
+
+    def __find_last_result[R: Result](self, *, messages: list[LLMMessage], response_model: type[R]) -> R:
+        for message in reversed(messages):
+            if message["role"] == "assistant":
+                try:
+                    only_part = message["content"][0]
+                    if only_part["type"] == "text":
+                        return response_model.model_validate_json(only_part["text"])
+                except ValidationError:
+                    continue
+
+        raise Honk("No last result found, failed to refine")
goose/_internal/conversation.py
CHANGED
@@ -2,18 +2,20 @@ from typing import Self
 
 from pydantic import BaseModel
 
+from goose.errors import Honk
+
 from .result import Result
 from .types.agent import AssistantMessage, LLMMessage, SystemMessage, UserMessage
 
 
 class Conversation[R: Result](BaseModel):
     user_messages: list[UserMessage]
-    …
+    assistant_messages: list[R | str]
     context: SystemMessage | None = None
 
     @property
     def awaiting_response(self) -> bool:
-        return len(self.user_messages) == len(self.…
+        return len(self.user_messages) == len(self.assistant_messages)
 
     def render(self) -> list[LLMMessage]:
         messages: list[LLMMessage] = []
@@ -21,15 +23,30 @@ class Conversation[R: Result](BaseModel):
             messages.append(self.context.render())
 
         for message_index in range(len(self.user_messages)):
-            …
+            message = self.assistant_messages[message_index]
+            if isinstance(message, str):
+                messages.append(AssistantMessage(text=message).render())
+            else:
+                messages.append(AssistantMessage(text=message.model_dump_json()).render())
+
             messages.append(self.user_messages[message_index].render())
 
-        if len(self.…
-            …
+        if len(self.assistant_messages) > len(self.user_messages):
+            message = self.assistant_messages[-1]
+            if isinstance(message, str):
+                messages.append(AssistantMessage(text=message).render())
+            else:
+                messages.append(AssistantMessage(text=message.model_dump_json()).render())
 
         return messages
 
     def undo(self) -> Self:
+        if len(self.user_messages) == 0:
+            raise Honk("Cannot undo, no user messages")
+
+        if len(self.assistant_messages) == 0:
+            raise Honk("Cannot undo, no assistant messages")
+
         self.user_messages.pop()
-        self.…
+        self.assistant_messages.pop()
         return self
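
Conversation now keeps a single assistant_messages: list[R | str] channel, so structured task results and free-form ask() answers interleave in one history. A self-contained sketch of the render() serialization rule above, assuming plain dicts as stand-ins for goose's LLMMessage and a hypothetical Fact result type:

from pydantic import BaseModel


class Fact(BaseModel):
    claim: str


def render(user_messages: list[str], assistant_messages: list[Fact | str]) -> list[dict[str, str]]:
    messages: list[dict[str, str]] = []
    # each turn pairs the assistant message at index i with the user message at index i
    for i in range(len(user_messages)):
        msg = assistant_messages[i]
        text = msg if isinstance(msg, str) else msg.model_dump_json()  # results serialize to JSON
        messages.append({"role": "assistant", "content": text})
        messages.append({"role": "user", "content": user_messages[i]})
    # a trailing assistant message (the latest response) is appended last
    if len(assistant_messages) > len(user_messages):
        msg = assistant_messages[-1]
        text = msg if isinstance(msg, str) else msg.model_dump_json()
        messages.append({"role": "assistant", "content": text})
    return messages


print(render(["why?"], [Fact(claim="the sky is blue"), "because of Rayleigh scattering"]))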
goose/_internal/result.py
CHANGED
@@ -1,4 +1,4 @@
-from pydantic import BaseModel, ConfigDict
+from pydantic import BaseModel, ConfigDict, Field
 
 
 class Result(BaseModel):
@@ -7,3 +7,14 @@ class Result(BaseModel):
 
 class TextResult(Result):
     text: str
+
+
+class Replacement(BaseModel):
+    find: str = Field(description="Text to find, to be replaced with `replace`")
+    replace: str = Field(description="Text to replace `find` with")
+
+
+class FindReplaceResponse(BaseModel):
+    replacements: list[Replacement] = Field(
+        description="List of replacements to make in the previous result to satisfy the user's request"
+    )
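
FindReplaceResponse is the structured output that drives the new refinement path: the model proposes find/replace pairs, which are applied as plain substring substitutions to the JSON dump of the previous result and then re-validated (mirroring Agent.__apply_find_replace above). A worked sketch, with Recipe as a hypothetical result model:

from pydantic import BaseModel, Field


class Replacement(BaseModel):
    find: str = Field(description="Text to find, to be replaced with `replace`")
    replace: str = Field(description="Text to replace `find` with")


class FindReplaceResponse(BaseModel):
    replacements: list[Replacement]


class Recipe(BaseModel):  # hypothetical result type for the example
    title: str
    servings: int


previous = Recipe(title="Pancakes for two", servings=2)
patch = FindReplaceResponse(
    replacements=[Replacement(find="two", replace="four"), Replacement(find="2", replace="4")]
)

# apply each pair as a literal substring replacement on the JSON dump...
dumped = previous.model_dump_json()
for r in patch.replacements:
    dumped = dumped.replace(r.find, r.replace)

# ...then re-validate the patched JSON against the original model
print(Recipe.model_validate_json(dumped))  # title='Pancakes for four' servings=4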
goose/_internal/state.py
CHANGED
@@ -4,15 +4,11 @@ from typing import TYPE_CHECKING, Any, NewType, Self
 
 from pydantic import BaseModel, ConfigDict
 
-from …
-from .…
-…
-…
-…
-    UserMessage,
-)
-from .conversation import Conversation
-from .result import Result
+from goose._internal.agent import Agent, IAgentLogger
+from goose._internal.conversation import Conversation
+from goose._internal.result import Result
+from goose._internal.types.agent import SystemMessage, UserMessage
+from goose.errors import Honk
 
 if TYPE_CHECKING:
     from goose._internal.task import Task
@@ -32,10 +28,11 @@ class NodeState[ResultT: Result](BaseModel):
 
     @property
    def result(self) -> ResultT:
-        …
-        …
+        for message in reversed(self.conversation.assistant_messages):
+            if isinstance(message, Result):
+                return message
 
-        …
+        raise Honk("Node awaiting response, has no result")
 
     def set_context(self, *, context: SystemMessage) -> Self:
         self.conversation.context = context
@@ -48,24 +45,33 @@ class NodeState[ResultT: Result](BaseModel):
         new_hash: int | None = None,
         overwrite: bool = False,
     ) -> Self:
-        if overwrite and len(self.conversation.…
-            self.conversation.…
+        if overwrite and len(self.conversation.assistant_messages) > 0:
+            self.conversation.assistant_messages[-1] = result
         else:
-            self.conversation.…
+            self.conversation.assistant_messages.append(result)
         if new_hash is not None:
             self.last_hash = new_hash
         return self
 
+    def add_answer(self, *, answer: str) -> Self:
+        self.conversation.assistant_messages.append(answer)
+        return self
+
     def add_user_message(self, *, message: UserMessage) -> Self:
         self.conversation.user_messages.append(message)
         return self
 
     def edit_last_result(self, *, result: ResultT) -> Self:
-        if len(self.conversation.…
+        if len(self.conversation.assistant_messages) == 0:
             raise Honk("Node awaiting response, has no result")
 
-        self.conversation.…
-        …
+        for message_index, message in enumerate(reversed(self.conversation.assistant_messages)):
+            if isinstance(message, Result):
+                index = len(self.conversation.assistant_messages) - message_index - 1
+                self.conversation.assistant_messages[index] = result
+                return self
+
+        raise Honk("Node awaiting response, has no result")
 
     def undo(self) -> Self:
         self.conversation.undo()
@@ -117,7 +123,7 @@ class FlowRun[FlowArgumentsT: FlowArguments]:
         return NodeState[task.result_type](
             task_name=task.name,
             index=index,
-            conversation=Conversation[task.result_type](user_messages=[], …
+            conversation=Conversation[task.result_type](user_messages=[], assistant_messages=[]),
             last_hash=0,
         )
 
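
With the mixed string/result history, NodeState.result now walks assistant_messages in reverse and returns the most recent structured result, skipping plain-string answers from ask(). A reduced sketch of that lookup; RuntimeError stands in for goose's Honk and Summary is a hypothetical result type:

from pydantic import BaseModel


class Result(BaseModel):  # stand-in for goose's Result base class
    pass


class Summary(Result):
    text: str


def last_result(messages: list[Result | str]) -> Result:
    # newest-first scan: string answers are passed over, the latest Result wins
    for message in reversed(messages):
        if isinstance(message, Result):
            return message
    raise RuntimeError("Node awaiting response, has no result")


history: list[Result | str] = [Summary(text="v1"), "just a free-form answer", Summary(text="v2")]
print(last_result(history))  # text='v2'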
goose/_internal/task.py
CHANGED
@@ -5,11 +5,10 @@ from typing import Any, overload
 from pydantic import BaseModel
 
 from ..errors import Honk
-from .agent import Agent, AIModel
-from .…
-from .result import Result, TextResult
+from .agent import Agent, AIModel
+from .result import Result
 from .state import FlowRun, NodeState, get_current_flow_run
-from .types.agent import …
+from .types.agent import SystemMessage, UserMessage
 
 
 class Task[**P, R: Result]:
@@ -19,12 +18,11 @@ class Task[**P, R: Result]:
         /,
         *,
         retries: int = 0,
-        …
+        refinement_model: AIModel = AIModel.GEMINI_FLASH,
     ) -> None:
         self._generator = generator
         self._retries = retries
-        self.…
-        self._adapter_model = adapter_model
+        self._refinement_model = refinement_model
 
     @property
     def result_type(self) -> type[R]:
@@ -46,6 +44,26 @@ class Task[**P, R: Result]:
         else:
             return state.result
 
+    async def ask(self, *, user_message: UserMessage, context: SystemMessage | None = None, index: int = 0) -> str:
+        flow_run = self.__get_current_flow_run()
+        node_state = flow_run.get(task=self, index=index)
+
+        if len(node_state.conversation.assistant_messages) == 0:
+            raise Honk("Cannot ask about a task that has not been initially generated")
+
+        node_state.add_user_message(message=user_message)
+        answer = await flow_run.agent(
+            messages=node_state.conversation.render(),
+            model=self._refinement_model,
+            task_name=f"ask--{self.name}",
+            system=context.render() if context is not None else None,
+            mode="ask",
+        )
+        node_state.add_answer(answer=answer)
+        flow_run.upsert_node_state(node_state)
+
+        return answer
+
     async def refine(
         self,
         *,
@@ -56,14 +74,20 @@ class Task[**P, R: Result]:
         flow_run = self.__get_current_flow_run()
         node_state = flow_run.get(task=self, index=index)
 
-        if len(node_state.conversation.…
+        if len(node_state.conversation.assistant_messages) == 0:
             raise Honk("Cannot refine a task that has not been initially generated")
 
         if context is not None:
             node_state.set_context(context=context)
         node_state.add_user_message(message=user_message)
 
-        result = await …
+        result = await flow_run.agent(
+            messages=node_state.conversation.render(),
+            model=self._refinement_model,
+            task_name=f"refine--{self.name}",
+            system=context.render() if context is not None else None,
+            response_model=self.result_type,
+        )
         node_state.add_result(result=result)
         flow_run.upsert_node_state(node_state)
 
@@ -88,28 +112,6 @@ class Task[**P, R: Result]:
         flow_run.upsert_node_state(node_state)
         return result
 
-    async def __adapt(self, *, conversation: Conversation[R], agent: Agent) -> R:
-        messages: list[UserMessage | AssistantMessage] = []
-        for message_index in range(len(conversation.user_messages)):
-            user_message = conversation.user_messages[message_index]
-            result = conversation.result_messages[message_index]
-
-            if isinstance(result, TextResult):
-                assistant_text = result.text
-            else:
-                assistant_text = result.model_dump_json()
-            assistant_message = AssistantMessage(text=assistant_text)
-            messages.append(assistant_message)
-            messages.append(user_message)
-
-        return await agent(
-            messages=messages,
-            model=self._adapter_model,
-            task_name=f"adapt--{self.name}",
-            system=conversation.context,
-            response_model=self.result_type,
-        )
-
     def __hash_task_call(self, *args: P.args, **kwargs: P.kwargs) -> int:
         def update_hash(argument: Any, current_hash: Any = hashlib.sha256()) -> None:
             try:
@@ -148,20 +150,20 @@
 def task[**P, R: Result](generator: Callable[P, Awaitable[R]], /) -> Task[P, R]: ...
 @overload
 def task[**P, R: Result](
-    *, retries: int = 0, …
+    *, retries: int = 0, refinement_model: AIModel = AIModel.GEMINI_FLASH
 ) -> Callable[[Callable[P, Awaitable[R]]], Task[P, R]]: ...
 def task[**P, R: Result](
     generator: Callable[P, Awaitable[R]] | None = None,
     /,
     *,
     retries: int = 0,
-    …
+    refinement_model: AIModel = AIModel.GEMINI_FLASH,
 ) -> Task[P, R] | Callable[[Callable[P, Awaitable[R]]], Task[P, R]]:
     if generator is None:
 
         def decorator(fn: Callable[P, Awaitable[R]]) -> Task[P, R]:
-            return Task(fn, retries=retries, …
+            return Task(fn, retries=retries, refinement_model=refinement_model)
 
         return decorator
 
-    return Task(generator, retries=retries, …
+    return Task(generator, retries=retries, refinement_model=refinement_model)
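
The task decorator keeps its dual form (bare, or called with options) while adapter_model becomes refinement_model. A standalone sketch of that overloaded-decorator shape with stand-in types; MiniTask and the string model names are illustrative, not the goose classes:

import asyncio
from collections.abc import Awaitable, Callable
from typing import overload


class MiniTask[**P, R]:
    def __init__(
        self, fn: Callable[P, Awaitable[R]], /, *, retries: int = 0, refinement_model: str = "flash"
    ) -> None:
        self._fn = fn
        self._retries = retries
        self._refinement_model = refinement_model

    async def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R:
        return await self._fn(*args, **kwargs)


@overload
def task[**P, R](fn: Callable[P, Awaitable[R]], /) -> MiniTask[P, R]: ...
@overload
def task[**P, R](
    *, retries: int = 0, refinement_model: str = "flash"
) -> Callable[[Callable[P, Awaitable[R]]], MiniTask[P, R]]: ...
def task[**P, R](
    fn: Callable[P, Awaitable[R]] | None = None, /, *, retries: int = 0, refinement_model: str = "flash"
) -> MiniTask[P, R] | Callable[[Callable[P, Awaitable[R]]], MiniTask[P, R]]:
    if fn is None:
        # called with options: return a decorator that captures them
        return lambda f: MiniTask(f, retries=retries, refinement_model=refinement_model)
    # called bare: wrap the function directly
    return MiniTask(fn, retries=retries, refinement_model=refinement_model)


@task
async def bare() -> str:
    return "bare form"


@task(retries=2, refinement_model="pro")
async def configured() -> str:
    return "configured form"


print(asyncio.run(bare()), "/", asyncio.run(configured()))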
goose/_internal/types/telemetry.py
ADDED
@@ -0,0 +1,113 @@
+import json
+from datetime import datetime
+from typing import ClassVar, TypedDict
+
+from pydantic import BaseModel, computed_field
+
+from ..types.agent import AIModel, LLMMessage
+
+
+class AgentResponseDump(TypedDict):
+    run_id: str
+    flow_name: str
+    task_name: str
+    model: str
+    system_message: str
+    input_messages: list[str]
+    output_message: str
+    input_cost: float
+    output_cost: float
+    total_cost: float
+    input_tokens: int
+    output_tokens: int
+    start_time: datetime
+    end_time: datetime
+    duration_ms: int
+
+
+class AgentResponse[R: BaseModel | str](BaseModel):
+    INPUT_DOLLARS_PER_MILLION_TOKENS: ClassVar[dict[AIModel, float]] = {
+        AIModel.VERTEX_FLASH_8B: 0.30,
+        AIModel.VERTEX_FLASH: 0.15,
+        AIModel.VERTEX_PRO: 5.00,
+        AIModel.GEMINI_FLASH_8B: 0.30,
+        AIModel.GEMINI_FLASH: 0.15,
+        AIModel.GEMINI_PRO: 5.00,
+    }
+    OUTPUT_DOLLARS_PER_MILLION_TOKENS: ClassVar[dict[AIModel, float]] = {
+        AIModel.VERTEX_FLASH_8B: 0.30,
+        AIModel.VERTEX_FLASH: 0.15,
+        AIModel.VERTEX_PRO: 5.00,
+        AIModel.GEMINI_FLASH_8B: 0.30,
+        AIModel.GEMINI_FLASH: 0.15,
+        AIModel.GEMINI_PRO: 5.00,
+    }
+
+    response: R
+    run_id: str
+    flow_name: str
+    task_name: str
+    model: AIModel
+    system: LLMMessage | None = None
+    input_messages: list[LLMMessage]
+    input_tokens: int
+    output_tokens: int
+    start_time: datetime
+    end_time: datetime
+
+    @computed_field
+    @property
+    def duration_ms(self) -> int:
+        return int((self.end_time - self.start_time).total_seconds() * 1000)
+
+    @computed_field
+    @property
+    def input_cost(self) -> float:
+        return self.INPUT_DOLLARS_PER_MILLION_TOKENS[self.model] * self.input_tokens / 1_000_000
+
+    @computed_field
+    @property
+    def output_cost(self) -> float:
+        return self.OUTPUT_DOLLARS_PER_MILLION_TOKENS[self.model] * self.output_tokens / 1_000_000
+
+    @computed_field
+    @property
+    def total_cost(self) -> float:
+        return self.input_cost + self.output_cost
+
+    def minimized_dump(self) -> AgentResponseDump:
+        if self.system is None:
+            minimized_system_message = ""
+        else:
+            minimized_system_message = self.system
+            for part in minimized_system_message["content"]:
+                if part["type"] == "image_url":
+                    part["image_url"] = "__MEDIA__"
+            minimized_system_message = json.dumps(minimized_system_message)
+
+        minimized_input_messages = [message for message in self.input_messages]
+        for message in minimized_input_messages:
+            for part in message["content"]:
+                if part["type"] == "image_url":
+                    part["image_url"] = "__MEDIA__"
+        minimized_input_messages = [json.dumps(message) for message in minimized_input_messages]
+
+        output_message = self.response.model_dump_json() if isinstance(self.response, BaseModel) else self.response
+
+        return {
+            "run_id": self.run_id,
+            "flow_name": self.flow_name,
+            "task_name": self.task_name,
+            "model": self.model.value,
+            "system_message": minimized_system_message,
+            "input_messages": minimized_input_messages,
+            "output_message": output_message,
+            "input_tokens": self.input_tokens,
+            "output_tokens": self.output_tokens,
+            "input_cost": self.input_cost,
+            "output_cost": self.output_cost,
+            "total_cost": self.total_cost,
+            "start_time": self.start_time,
+            "end_time": self.end_time,
+            "duration_ms": self.duration_ms,
+        }
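
The cost tables moved from cents to dollars per million tokens (30 cents becomes $0.30, 500 cents becomes $5.00, so the effective rates are unchanged), and the computed fields keep the same rate × tokens / 1_000_000 shape. A worked example using the GEMINI_FLASH rates from the table above; the token counts are made up:

# dollars per million tokens, per the GEMINI_FLASH entries in the diff
input_rate = 0.15
output_rate = 0.15

input_tokens = 12_000   # hypothetical prompt size
output_tokens = 800     # hypothetical completion size

input_cost = input_rate * input_tokens / 1_000_000     # 0.0018
output_cost = output_rate * output_tokens / 1_000_000  # 0.00012
total_cost = input_cost + output_cost                  # 0.00192

print(f"${total_cost:.5f}")  # $0.00192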
{goose_py-0.9.15.dist-info → goose_py-0.10.0.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: goose-py
-Version: 0.9.15
+Version: 0.10.0
 Summary: A tool for AI workflows based on human-computer collaboration and structured output.
 Author-email: Nash Taylor <nash@chelle.ai>, Joshua Cook <joshua@chelle.ai>, Michael Sankur <michael@chelle.ai>
 Requires-Python: >=3.12
{goose_py-0.9.15.dist-info → goose_py-0.10.0.dist-info}/RECORD
CHANGED
@@ -5,15 +5,16 @@ goose/flow.py,sha256=YsZLBa5I1W27_P6LYGWbtFX8ZYx9vJG3KtENYChHm5E,111
 goose/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 goose/runs.py,sha256=ub-r_gzbUbaIzWXX-jc-dncNxEh6zTfzIkmnDfCSbRI,160
 goose/task.py,sha256=95rspdxETJoY12IHBl3KjnVIdqQnf1jDKlnGWNWOTvQ,53
-goose/_internal/agent.py,sha256=…
-goose/_internal/conversation.py,sha256=…
+goose/_internal/agent.py,sha256=v_nzQLM_vyKWOc-g1Ke8TwlThkThjraE4HFHNywHMv0,7915
+goose/_internal/conversation.py,sha256=zvKqLxJSCIIuhD7gjcSFhleYsLabu-ALl9woWFy3mQU,1766
 goose/_internal/flow.py,sha256=RShMsxgt49g1fZJ3rlwDHtI1j39lZzewx8hZ7DGN5kg,4124
-goose/_internal/result.py,sha256=…
-goose/_internal/state.py,sha256=…
+goose/_internal/result.py,sha256=vtJMfBxb9skfl8st2tn4hBmEq6qmXiJTme_B5QTgu2M,538
+goose/_internal/state.py,sha256=U4gM0K4MAlRFTpqenCYHX9TYGuhWVKIfa4yBeZ9Qc9s,7090
 goose/_internal/store.py,sha256=tWmKfa1-yq1jU6lT3l6kSOmVt2m3H7I1xLMTrxnUDI8,889
-goose/_internal/task.py,sha256=…
+goose/_internal/task.py,sha256=XObr-fX9oH3TjgxHXQAUpKR2Zvup91uWSpfBeBbUJbU,6225
 goose/_internal/types/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 goose/_internal/types/agent.py,sha256=g0KD-aPWZlUGBx72AwQd3LeniFxHATeflZ7191QjFZA,2696
-…
-goose_py-0.9.15.dist-info/…
-goose_py-0.9.15.dist-info/…
+goose/_internal/types/telemetry.py,sha256=7zeqyDDxf95puirNM6Gr9VFuxoDshXcV1__V0tiMswE,3663
+goose_py-0.10.0.dist-info/METADATA,sha256=ct-j5aAGBQsa1BNBKR00bj5zitPsNyPcjrDchAWqH54,442
+goose_py-0.10.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+goose_py-0.10.0.dist-info/RECORD,,

{goose_py-0.9.15.dist-info → goose_py-0.10.0.dist-info}/WHEEL
File without changes