goose-py 0.4.0__tar.gz → 0.4.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {goose_py-0.4.0 → goose_py-0.4.1}/PKG-INFO +1 -1
- {goose_py-0.4.0 → goose_py-0.4.1}/goose/agent.py +30 -20
- {goose_py-0.4.0 → goose_py-0.4.1}/goose/flow.py +2 -5
- goose_py-0.4.1/goose/result.py +9 -0
- {goose_py-0.4.0 → goose_py-0.4.1}/pyproject.toml +1 -1
- {goose_py-0.4.0 → goose_py-0.4.1}/README.md +0 -0
- {goose_py-0.4.0 → goose_py-0.4.1}/goose/__init__.py +0 -0
- {goose_py-0.4.0 → goose_py-0.4.1}/goose/errors.py +0 -0
- {goose_py-0.4.0 → goose_py-0.4.1}/goose/py.typed +0 -0
- {goose_py-0.4.0 → goose_py-0.4.1}/goose/store.py +0 -0
@@ -7,6 +7,7 @@ from typing import Any, ClassVar, Literal, NotRequired, Protocol, TypedDict
|
|
7
7
|
|
8
8
|
from litellm import acompletion
|
9
9
|
from pydantic import BaseModel, computed_field
|
10
|
+
from goose.result import Result, TextResult
|
10
11
|
|
11
12
|
|
12
13
|
class GeminiModel(StrEnum):
|
@@ -115,7 +116,7 @@ class AgentResponseDump(TypedDict):
|
|
115
116
|
duration_ms: int
|
116
117
|
|
117
118
|
|
118
|
-
class AgentResponse[R: BaseModel](BaseModel):
|
119
|
+
class AgentResponse[R: BaseModel | str](BaseModel):
|
119
120
|
INPUT_CENTS_PER_MILLION_TOKENS: ClassVar[dict[GeminiModel, float]] = {
|
120
121
|
GeminiModel.FLASH_8B: 30,
|
121
122
|
GeminiModel.FLASH: 15,
|
@@ -186,6 +187,12 @@ class AgentResponse[R: BaseModel](BaseModel):
|
|
186
187
|
json.dumps(message) for message in minimized_input_messages
|
187
188
|
]
|
188
189
|
|
190
|
+
output_message = (
|
191
|
+
self.response.model_dump_json()
|
192
|
+
if isinstance(self.response, BaseModel)
|
193
|
+
else self.response
|
194
|
+
)
|
195
|
+
|
189
196
|
return {
|
190
197
|
"run_id": self.run_id,
|
191
198
|
"flow_name": self.flow_name,
|
@@ -193,7 +200,7 @@ class AgentResponse[R: BaseModel](BaseModel):
|
|
193
200
|
"model": self.model.value,
|
194
201
|
"system_message": minimized_system_message,
|
195
202
|
"input_messages": minimized_input_messages,
|
196
|
-
"output_message":
|
203
|
+
"output_message": output_message,
|
197
204
|
"input_tokens": self.input_tokens,
|
198
205
|
"output_tokens": self.output_tokens,
|
199
206
|
"input_cost": self.input_cost,
|
@@ -221,13 +228,13 @@ class Agent:
|
|
221
228
|
self.run_id = run_id
|
222
229
|
self.logger = logger
|
223
230
|
|
224
|
-
async def __call__[R:
|
231
|
+
async def __call__[R: Result](
|
225
232
|
self,
|
226
233
|
*,
|
227
234
|
messages: list[UserMessage | AssistantMessage],
|
228
235
|
model: GeminiModel,
|
229
|
-
response_model: type[R],
|
230
236
|
task_name: str,
|
237
|
+
response_model: type[R] = TextResult,
|
231
238
|
system: SystemMessage | None = None,
|
232
239
|
) -> R:
|
233
240
|
start_time = datetime.now()
|
@@ -235,22 +242,25 @@ class Agent:
|
|
235
242
|
if system is not None:
|
236
243
|
rendered_messages.insert(0, system.render())
|
237
244
|
|
238
|
-
|
239
|
-
model=model.value,
|
240
|
-
|
241
|
-
|
242
|
-
|
243
|
-
|
244
|
-
|
245
|
-
|
246
|
-
|
247
|
-
|
248
|
-
|
249
|
-
|
245
|
+
if response_model is TextResult:
|
246
|
+
response = await acompletion(model=model.value, messages=rendered_messages)
|
247
|
+
parsed_response = response_model.model_validate(
|
248
|
+
{"text": response.choices[0].message.content}
|
249
|
+
)
|
250
|
+
else:
|
251
|
+
response = await acompletion(
|
252
|
+
model=model.value,
|
253
|
+
messages=rendered_messages,
|
254
|
+
response_format={
|
255
|
+
"type": "json_object",
|
256
|
+
"response_schema": response_model.model_json_schema(),
|
257
|
+
"enforce_validation": True,
|
258
|
+
},
|
259
|
+
)
|
260
|
+
parsed_response = response_model.model_validate_json(
|
261
|
+
response.choices[0].message.content
|
262
|
+
)
|
250
263
|
|
251
|
-
parsed_response = response_model.model_validate_json(
|
252
|
-
response.choices[0].message.content
|
253
|
-
)
|
254
264
|
end_time = datetime.now()
|
255
265
|
agent_response = AgentResponse(
|
256
266
|
response=parsed_response,
|
@@ -271,4 +281,4 @@ class Agent:
|
|
271
281
|
else:
|
272
282
|
logging.info(agent_response.model_dump())
|
273
283
|
|
274
|
-
return
|
284
|
+
return parsed_response
|
@@ -13,7 +13,7 @@ from typing import (
|
|
13
13
|
overload,
|
14
14
|
)
|
15
15
|
|
16
|
-
from pydantic import BaseModel
|
16
|
+
from pydantic import BaseModel
|
17
17
|
|
18
18
|
from goose.agent import (
|
19
19
|
Agent,
|
@@ -25,14 +25,11 @@ from goose.agent import (
|
|
25
25
|
)
|
26
26
|
from goose.errors import Honk
|
27
27
|
from goose.store import IFlowRunStore, InMemoryFlowRunStore
|
28
|
+
from goose.result import Result
|
28
29
|
|
29
30
|
SerializedFlowRun = NewType("SerializedFlowRun", str)
|
30
31
|
|
31
32
|
|
32
|
-
class Result(BaseModel):
|
33
|
-
model_config = ConfigDict(frozen=True)
|
34
|
-
|
35
|
-
|
36
33
|
class Conversation[R: Result](BaseModel):
|
37
34
|
user_messages: list[UserMessage]
|
38
35
|
result_messages: list[R]
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|