goose-py 0.3.1__tar.gz → 0.3.2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {goose_py-0.3.1 → goose_py-0.3.2}/PKG-INFO +1 -1
- {goose_py-0.3.1 → goose_py-0.3.2}/goose/agent.py +7 -3
- {goose_py-0.3.1 → goose_py-0.3.2}/goose/flow.py +34 -9
- {goose_py-0.3.1 → goose_py-0.3.2}/pyproject.toml +1 -1
- {goose_py-0.3.1 → goose_py-0.3.2}/README.md +0 -0
- {goose_py-0.3.1 → goose_py-0.3.2}/goose/__init__.py +0 -0
- {goose_py-0.3.1 → goose_py-0.3.2}/goose/errors.py +0 -0
- {goose_py-0.3.1 → goose_py-0.3.2}/goose/py.typed +0 -0
@@ -2,7 +2,7 @@ import base64
|
|
2
2
|
import logging
|
3
3
|
from datetime import datetime
|
4
4
|
from enum import StrEnum
|
5
|
-
from typing import Any,
|
5
|
+
from typing import Any, ClassVar, Literal, NotRequired, Protocol, TypedDict
|
6
6
|
|
7
7
|
from litellm import acompletion
|
8
8
|
from pydantic import BaseModel, computed_field
|
@@ -144,13 +144,17 @@ class AgentResponse[R: BaseModel](BaseModel):
|
|
144
144
|
return self.input_cost + self.output_cost
|
145
145
|
|
146
146
|
|
147
|
+
class IAgentLogger(Protocol):
|
148
|
+
async def __call__(self, *, response: AgentResponse[Any]) -> None: ...
|
149
|
+
|
150
|
+
|
147
151
|
class Agent:
|
148
152
|
def __init__(
|
149
153
|
self,
|
150
154
|
*,
|
151
155
|
flow_name: str,
|
152
156
|
run_id: str,
|
153
|
-
logger:
|
157
|
+
logger: IAgentLogger | None = None,
|
154
158
|
) -> None:
|
155
159
|
self.flow_name = flow_name
|
156
160
|
self.run_id = run_id
|
@@ -202,7 +206,7 @@ class Agent:
|
|
202
206
|
)
|
203
207
|
|
204
208
|
if self.logger is not None:
|
205
|
-
await self.logger(agent_response)
|
209
|
+
await self.logger(response=agent_response)
|
206
210
|
else:
|
207
211
|
logging.info(agent_response.model_dump())
|
208
212
|
|
@@ -14,7 +14,14 @@ from typing import (
|
|
14
14
|
|
15
15
|
from pydantic import BaseModel, ConfigDict, field_validator
|
16
16
|
|
17
|
-
from goose.agent import
|
17
|
+
from goose.agent import (
|
18
|
+
Agent,
|
19
|
+
AssistantMessage,
|
20
|
+
IAgentLogger,
|
21
|
+
LLMMessage,
|
22
|
+
SystemMessage,
|
23
|
+
UserMessage,
|
24
|
+
)
|
18
25
|
from goose.errors import Honk
|
19
26
|
|
20
27
|
SerializedFlowRun = NewType("SerializedFlowRun", str)
|
@@ -174,11 +181,19 @@ class FlowRun:
|
|
174
181
|
last_input_hash=0,
|
175
182
|
)
|
176
183
|
|
177
|
-
def start(
|
184
|
+
def start(
|
185
|
+
self,
|
186
|
+
*,
|
187
|
+
flow_name: str,
|
188
|
+
run_id: str,
|
189
|
+
agent_logger: IAgentLogger | None = None,
|
190
|
+
) -> None:
|
178
191
|
self._last_requested_indices = {}
|
179
192
|
self._flow_name = flow_name
|
180
193
|
self._id = run_id
|
181
|
-
self._agent = Agent(
|
194
|
+
self._agent = Agent(
|
195
|
+
flow_name=self.flow_name, run_id=self.id, logger=agent_logger
|
196
|
+
)
|
182
197
|
|
183
198
|
def end(self) -> None:
|
184
199
|
self._last_requested_indices = {}
|
@@ -216,10 +231,16 @@ _current_flow_run: ContextVar[FlowRun | None] = ContextVar(
|
|
216
231
|
|
217
232
|
class Flow[**P]:
|
218
233
|
def __init__(
|
219
|
-
self,
|
234
|
+
self,
|
235
|
+
fn: Callable[P, Awaitable[None]],
|
236
|
+
/,
|
237
|
+
*,
|
238
|
+
name: str | None = None,
|
239
|
+
agent_logger: IAgentLogger | None = None,
|
220
240
|
) -> None:
|
221
241
|
self._fn = fn
|
222
242
|
self._name = name
|
243
|
+
self._agent_logger = agent_logger
|
223
244
|
|
224
245
|
@property
|
225
246
|
def name(self) -> str:
|
@@ -244,7 +265,7 @@ class Flow[**P]:
|
|
244
265
|
old_run = _current_flow_run.get()
|
245
266
|
_current_flow_run.set(run)
|
246
267
|
|
247
|
-
run.start(flow_name=self.name, run_id=run_id)
|
268
|
+
run.start(flow_name=self.name, run_id=run_id, agent_logger=self._agent_logger)
|
248
269
|
yield run
|
249
270
|
run.end()
|
250
271
|
|
@@ -365,16 +386,20 @@ def task[**P, R: Result](
|
|
365
386
|
def flow[**P](fn: Callable[P, Awaitable[None]], /) -> Flow[P]: ...
|
366
387
|
@overload
|
367
388
|
def flow[**P](
|
368
|
-
*, name: str | None = None
|
389
|
+
*, name: str | None = None, agent_logger: IAgentLogger | None = None
|
369
390
|
) -> Callable[[Callable[P, Awaitable[None]]], Flow[P]]: ...
|
370
391
|
def flow[**P](
|
371
|
-
fn: Callable[P, Awaitable[None]] | None = None,
|
392
|
+
fn: Callable[P, Awaitable[None]] | None = None,
|
393
|
+
/,
|
394
|
+
*,
|
395
|
+
name: str | None = None,
|
396
|
+
agent_logger: IAgentLogger | None = None,
|
372
397
|
) -> Flow[P] | Callable[[Callable[P, Awaitable[None]]], Flow[P]]:
|
373
398
|
if fn is None:
|
374
399
|
|
375
400
|
def decorator(fn: Callable[P, Awaitable[None]]) -> Flow[P]:
|
376
|
-
return Flow(fn, name=name)
|
401
|
+
return Flow(fn, name=name, agent_logger=agent_logger)
|
377
402
|
|
378
403
|
return decorator
|
379
404
|
|
380
|
-
return Flow(fn, name=name)
|
405
|
+
return Flow(fn, name=name, agent_logger=agent_logger)
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|