goose-py 0.3.12__py3-none-any.whl → 0.4.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
goose/agent.py
CHANGED
@@ -7,6 +7,7 @@ from typing import Any, ClassVar, Literal, NotRequired, Protocol, TypedDict
|
|
7
7
|
|
8
8
|
from litellm import acompletion
|
9
9
|
from pydantic import BaseModel, computed_field
|
10
|
+
from goose.result import Result, TextResult
|
10
11
|
|
11
12
|
|
12
13
|
class GeminiModel(StrEnum):
|
@@ -115,7 +116,7 @@ class AgentResponseDump(TypedDict):
|
|
115
116
|
duration_ms: int
|
116
117
|
|
117
118
|
|
118
|
-
class AgentResponse[R: BaseModel](BaseModel):
|
119
|
+
class AgentResponse[R: BaseModel | str](BaseModel):
|
119
120
|
INPUT_CENTS_PER_MILLION_TOKENS: ClassVar[dict[GeminiModel, float]] = {
|
120
121
|
GeminiModel.FLASH_8B: 30,
|
121
122
|
GeminiModel.FLASH: 15,
|
@@ -186,6 +187,12 @@ class AgentResponse[R: BaseModel](BaseModel):
|
|
186
187
|
json.dumps(message) for message in minimized_input_messages
|
187
188
|
]
|
188
189
|
|
190
|
+
output_message = (
|
191
|
+
self.response.model_dump_json()
|
192
|
+
if isinstance(self.response, BaseModel)
|
193
|
+
else self.response
|
194
|
+
)
|
195
|
+
|
189
196
|
return {
|
190
197
|
"run_id": self.run_id,
|
191
198
|
"flow_name": self.flow_name,
|
@@ -193,7 +200,7 @@ class AgentResponse[R: BaseModel](BaseModel):
|
|
193
200
|
"model": self.model.value,
|
194
201
|
"system_message": minimized_system_message,
|
195
202
|
"input_messages": minimized_input_messages,
|
196
|
-
"output_message":
|
203
|
+
"output_message": output_message,
|
197
204
|
"input_tokens": self.input_tokens,
|
198
205
|
"output_tokens": self.output_tokens,
|
199
206
|
"input_cost": self.input_cost,
|
@@ -221,13 +228,13 @@ class Agent:
|
|
221
228
|
self.run_id = run_id
|
222
229
|
self.logger = logger
|
223
230
|
|
224
|
-
async def __call__[R:
|
231
|
+
async def __call__[R: Result](
|
225
232
|
self,
|
226
233
|
*,
|
227
234
|
messages: list[UserMessage | AssistantMessage],
|
228
235
|
model: GeminiModel,
|
229
|
-
response_model: type[R],
|
230
236
|
task_name: str,
|
237
|
+
response_model: type[R] = TextResult,
|
231
238
|
system: SystemMessage | None = None,
|
232
239
|
) -> R:
|
233
240
|
start_time = datetime.now()
|
@@ -235,22 +242,25 @@ class Agent:
|
|
235
242
|
if system is not None:
|
236
243
|
rendered_messages.insert(0, system.render())
|
237
244
|
|
238
|
-
|
239
|
-
model=model.value,
|
240
|
-
|
241
|
-
|
242
|
-
|
243
|
-
|
244
|
-
|
245
|
-
|
246
|
-
|
247
|
-
|
248
|
-
|
249
|
-
|
245
|
+
if response_model is TextResult:
|
246
|
+
response = await acompletion(model=model.value, messages=rendered_messages)
|
247
|
+
parsed_response = response_model.model_validate(
|
248
|
+
{"text": response.choices[0].message.content}
|
249
|
+
)
|
250
|
+
else:
|
251
|
+
response = await acompletion(
|
252
|
+
model=model.value,
|
253
|
+
messages=rendered_messages,
|
254
|
+
response_format={
|
255
|
+
"type": "json_object",
|
256
|
+
"response_schema": response_model.model_json_schema(),
|
257
|
+
"enforce_validation": True,
|
258
|
+
},
|
259
|
+
)
|
260
|
+
parsed_response = response_model.model_validate_json(
|
261
|
+
response.choices[0].message.content
|
262
|
+
)
|
250
263
|
|
251
|
-
parsed_response = response_model.model_validate_json(
|
252
|
-
response.choices[0].message.content
|
253
|
-
)
|
254
264
|
end_time = datetime.now()
|
255
265
|
agent_response = AgentResponse(
|
256
266
|
response=parsed_response,
|
@@ -271,4 +281,4 @@ class Agent:
|
|
271
281
|
else:
|
272
282
|
logging.info(agent_response.model_dump())
|
273
283
|
|
274
|
-
return
|
284
|
+
return parsed_response
|
goose/flow.py
CHANGED
@@ -13,7 +13,7 @@ from typing import (
|
|
13
13
|
overload,
|
14
14
|
)
|
15
15
|
|
16
|
-
from pydantic import BaseModel
|
16
|
+
from pydantic import BaseModel
|
17
17
|
|
18
18
|
from goose.agent import (
|
19
19
|
Agent,
|
@@ -25,14 +25,11 @@ from goose.agent import (
|
|
25
25
|
)
|
26
26
|
from goose.errors import Honk
|
27
27
|
from goose.store import IFlowRunStore, InMemoryFlowRunStore
|
28
|
+
from goose.result import Result
|
28
29
|
|
29
30
|
SerializedFlowRun = NewType("SerializedFlowRun", str)
|
30
31
|
|
31
32
|
|
32
|
-
class Result(BaseModel):
|
33
|
-
model_config = ConfigDict(frozen=True)
|
34
|
-
|
35
|
-
|
36
33
|
class Conversation[R: Result](BaseModel):
|
37
34
|
user_messages: list[UserMessage]
|
38
35
|
result_messages: list[R]
|
@@ -115,6 +112,8 @@ class FlowRun:
|
|
115
112
|
self._flow_name = ""
|
116
113
|
self._id = ""
|
117
114
|
self._agent: Agent | None = None
|
115
|
+
self._flow_args: tuple[Any, ...] | None = None
|
116
|
+
self._flow_kwargs: dict[str, Any] | None = None
|
118
117
|
|
119
118
|
@property
|
120
119
|
def flow_name(self) -> str:
|
@@ -130,17 +129,12 @@ class FlowRun:
|
|
130
129
|
raise Honk("Agent is only accessible once a run is started")
|
131
130
|
return self._agent
|
132
131
|
|
133
|
-
|
134
|
-
|
135
|
-
self.
|
136
|
-
|
137
|
-
def get_next[R: Result](self, *, task: "Task[Any, R]") -> NodeState[R]:
|
138
|
-
if task.name not in self._last_requested_indices:
|
139
|
-
self._last_requested_indices[task.name] = 0
|
140
|
-
else:
|
141
|
-
self._last_requested_indices[task.name] += 1
|
132
|
+
@property
|
133
|
+
def flow_inputs(self) -> tuple[tuple[Any, ...], dict[str, Any]]:
|
134
|
+
if self._flow_args is None or self._flow_kwargs is None:
|
135
|
+
raise Honk("This Flow run has not been executed before")
|
142
136
|
|
143
|
-
return self.
|
137
|
+
return self._flow_args, self._flow_kwargs
|
144
138
|
|
145
139
|
def get_all[R: Result](self, *, task: "Task[Any, R]") -> list[NodeState[R]]:
|
146
140
|
matching_nodes: list[NodeState[R]] = []
|
@@ -166,6 +160,22 @@ class FlowRun:
|
|
166
160
|
last_hash=0,
|
167
161
|
)
|
168
162
|
|
163
|
+
def set_flow_inputs(self, *args: Any, **kwargs: Any) -> None:
|
164
|
+
self._flow_args = args
|
165
|
+
self._flow_kwargs = kwargs
|
166
|
+
|
167
|
+
def add_node_state(self, node_state: NodeState[Any], /) -> None:
|
168
|
+
key = (node_state.task_name, node_state.index)
|
169
|
+
self._node_states[key] = node_state.model_dump_json()
|
170
|
+
|
171
|
+
def get_next[R: Result](self, *, task: "Task[Any, R]") -> NodeState[R]:
|
172
|
+
if task.name not in self._last_requested_indices:
|
173
|
+
self._last_requested_indices[task.name] = 0
|
174
|
+
else:
|
175
|
+
self._last_requested_indices[task.name] += 1
|
176
|
+
|
177
|
+
return self.get(task=task, index=self._last_requested_indices[task.name])
|
178
|
+
|
169
179
|
def start(
|
170
180
|
self,
|
171
181
|
*,
|
@@ -192,25 +202,35 @@ class FlowRun:
|
|
192
202
|
del self._node_states[key]
|
193
203
|
|
194
204
|
def dump(self) -> SerializedFlowRun:
|
205
|
+
flow_args, flow_kwargs = self.flow_inputs
|
206
|
+
|
195
207
|
return SerializedFlowRun(
|
196
208
|
json.dumps(
|
197
209
|
{
|
198
|
-
"
|
199
|
-
|
210
|
+
"node_states": {
|
211
|
+
":".join([task_name, str(index)]): value
|
212
|
+
for (task_name, index), value in self._node_states.items()
|
213
|
+
},
|
214
|
+
"flow_args": list(flow_args),
|
215
|
+
"flow_kwargs": flow_kwargs,
|
200
216
|
}
|
201
217
|
)
|
202
218
|
)
|
203
219
|
|
204
220
|
@classmethod
|
205
|
-
def load(cls,
|
221
|
+
def load(cls, serialized_flow_run: SerializedFlowRun, /) -> Self:
|
206
222
|
flow_run = cls()
|
207
|
-
|
223
|
+
run = json.loads(serialized_flow_run)
|
224
|
+
|
208
225
|
new_node_states: dict[tuple[str, int], str] = {}
|
209
|
-
for key, node_state in
|
226
|
+
for key, node_state in run["node_states"].items():
|
210
227
|
task_name, index = tuple(key.split(":"))
|
211
228
|
new_node_states[(task_name, int(index))] = node_state
|
212
|
-
|
213
229
|
flow_run._node_states = new_node_states
|
230
|
+
|
231
|
+
flow_run._flow_args = tuple(run["flow_args"])
|
232
|
+
flow_run._flow_kwargs = run["flow_kwargs"]
|
233
|
+
|
214
234
|
return flow_run
|
215
235
|
|
216
236
|
|
@@ -264,8 +284,21 @@ class Flow[**P]:
|
|
264
284
|
_current_flow_run.set(old_run)
|
265
285
|
|
266
286
|
async def generate(self, *args: P.args, **kwargs: P.kwargs) -> None:
|
287
|
+
flow_run = _current_flow_run.get()
|
288
|
+
if flow_run is None:
|
289
|
+
raise Honk("No current flow run")
|
290
|
+
|
291
|
+
flow_run.set_flow_inputs(*args, **kwargs)
|
267
292
|
await self._fn(*args, **kwargs)
|
268
293
|
|
294
|
+
async def regenerate(self) -> None:
|
295
|
+
flow_run = _current_flow_run.get()
|
296
|
+
if flow_run is None:
|
297
|
+
raise Honk("No current flow run")
|
298
|
+
|
299
|
+
flow_args, flow_kwargs = flow_run.flow_inputs
|
300
|
+
await self._fn(*flow_args, **flow_kwargs)
|
301
|
+
|
269
302
|
|
270
303
|
class Task[**P, R: Result]:
|
271
304
|
def __init__(
|
@@ -323,7 +356,7 @@ class Task[**P, R: Result]:
|
|
323
356
|
|
324
357
|
result = await self._adapter(conversation=node_state.conversation)
|
325
358
|
node_state.add_result(result=result)
|
326
|
-
flow_run.
|
359
|
+
flow_run.add_node_state(node_state)
|
327
360
|
|
328
361
|
return result
|
329
362
|
|
@@ -331,7 +364,7 @@ class Task[**P, R: Result]:
|
|
331
364
|
flow_run = self.__get_current_flow_run()
|
332
365
|
node_state = flow_run.get_next(task=self)
|
333
366
|
result = await self.generate(node_state, *args, **kwargs)
|
334
|
-
flow_run.
|
367
|
+
flow_run.add_node_state(node_state)
|
335
368
|
return result
|
336
369
|
|
337
370
|
def __hash_task_call(self, *args: P.args, **kwargs: P.kwargs) -> int:
|
goose/result.py
ADDED
@@ -0,0 +1,10 @@
|
|
1
|
+
goose/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
2
|
+
goose/agent.py,sha256=S9nDhK7v4Rsk-ANuq5tIZfZ7TOmCdMEVid2VuK1EWCw,8106
|
3
|
+
goose/errors.py,sha256=-0OyZQJWYTRw5YgnCB2_uorVaUsL6Z0QYQO2FqzCiyg,32
|
4
|
+
goose/flow.py,sha256=ZoMk3TlTLG1KM9ZxxR6lAuzDxzPZgUeWAi5mm4Ch7L4,13428
|
5
|
+
goose/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
6
|
+
goose/result.py,sha256=-eZJn-2sPo7rHZ38Sz6IAHXqiJ-Ss39esEoFGimJEBI,155
|
7
|
+
goose/store.py,sha256=4p2BBVAEUS1_Z0iBk5Qk_fPxRQeph64DRzXOFmjIT38,844
|
8
|
+
goose_py-0.4.1.dist-info/METADATA,sha256=AB7TA_rXnj1M_gQ8Y_PnX7mhbdvJl7nlAAWf_S651fk,1106
|
9
|
+
goose_py-0.4.1.dist-info/WHEEL,sha256=Nq82e9rUAnEjt98J6MlVmMCZb-t9cYE2Ir1kpBmnWfs,88
|
10
|
+
goose_py-0.4.1.dist-info/RECORD,,
|
goose_py-0.3.12.dist-info/RECORD
DELETED
@@ -1,9 +0,0 @@
|
|
1
|
-
goose/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
2
|
-
goose/agent.py,sha256=pZkKKW2weXEX43p9-pVIpEtwei5nylyQ_YrNZWYFQM0,7687
|
3
|
-
goose/errors.py,sha256=-0OyZQJWYTRw5YgnCB2_uorVaUsL6Z0QYQO2FqzCiyg,32
|
4
|
-
goose/flow.py,sha256=eZMaMQNhXqL5EMJLjDIRyqFKv9K_WV6sJsc50Oms1Vg,12152
|
5
|
-
goose/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
6
|
-
goose/store.py,sha256=4p2BBVAEUS1_Z0iBk5Qk_fPxRQeph64DRzXOFmjIT38,844
|
7
|
-
goose_py-0.3.12.dist-info/METADATA,sha256=ATNQi96Z1ZwwicyVMPOcM371cuU5mkEVXn4R-Jm_k-s,1107
|
8
|
-
goose_py-0.3.12.dist-info/WHEEL,sha256=Nq82e9rUAnEjt98J6MlVmMCZb-t9cYE2Ir1kpBmnWfs,88
|
9
|
-
goose_py-0.3.12.dist-info/RECORD,,
|
File without changes
|