goose-py 0.2.1__tar.gz → 0.3.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {goose_py-0.2.1 → goose_py-0.3.0}/PKG-INFO +1 -1
- {goose_py-0.2.1 → goose_py-0.3.0}/goose/agent.py +20 -11
- {goose_py-0.2.1 → goose_py-0.3.0}/goose/flow.py +100 -65
- goose_py-0.3.0/goose/py.typed +0 -0
- {goose_py-0.2.1 → goose_py-0.3.0}/pyproject.toml +1 -1
- goose_py-0.2.1/goose/__init__.py +0 -3
- {goose_py-0.2.1 → goose_py-0.3.0}/README.md +0 -0
- /goose_py-0.2.1/goose/py.typed → /goose_py-0.3.0/goose/__init__.py +0 -0
- {goose_py-0.2.1 → goose_py-0.3.0}/goose/errors.py +0 -0
@@ -88,10 +88,13 @@ class AssistantMessage(BaseModel):
|
|
88
88
|
|
89
89
|
|
90
90
|
class SystemMessage(BaseModel):
|
91
|
-
|
91
|
+
parts: list[TextMessagePart | MediaMessagePart]
|
92
92
|
|
93
93
|
def render(self) -> LLMMessage:
|
94
|
-
return {
|
94
|
+
return {
|
95
|
+
"role": "system",
|
96
|
+
"content": [part.render() for part in self.parts],
|
97
|
+
}
|
95
98
|
|
96
99
|
|
97
100
|
class AgentResponse[R: BaseModel](BaseModel):
|
@@ -109,7 +112,7 @@ class AgentResponse[R: BaseModel](BaseModel):
|
|
109
112
|
}
|
110
113
|
|
111
114
|
response: R
|
112
|
-
|
115
|
+
run_name: str
|
113
116
|
flow_name: str
|
114
117
|
task_name: str
|
115
118
|
model: GeminiModel
|
@@ -125,14 +128,20 @@ class AgentResponse[R: BaseModel](BaseModel):
|
|
125
128
|
def duration_ms(self) -> int:
|
126
129
|
return int((self.end_time - self.start_time).total_seconds() * 1000)
|
127
130
|
|
131
|
+
@computed_field
|
132
|
+
@property
|
133
|
+
def input_cost(self) -> float:
|
134
|
+
return self.INPUT_CENTS_PER_MILLION_TOKENS[self.model] * self.input_tokens
|
135
|
+
|
136
|
+
@computed_field
|
137
|
+
@property
|
138
|
+
def output_cost(self) -> float:
|
139
|
+
return self.OUTPUT_CENTS_PER_MILLION_TOKENS[self.model] * self.output_tokens
|
140
|
+
|
128
141
|
@computed_field
|
129
142
|
@property
|
130
143
|
def total_cost(self) -> float:
|
131
|
-
|
132
|
-
output_cost = (
|
133
|
-
self.OUTPUT_CENTS_PER_MILLION_TOKENS[self.model] * self.output_tokens
|
134
|
-
)
|
135
|
-
return input_cost + output_cost
|
144
|
+
return self.input_cost + self.output_cost
|
136
145
|
|
137
146
|
|
138
147
|
class Agent:
|
@@ -140,11 +149,11 @@ class Agent:
|
|
140
149
|
self,
|
141
150
|
*,
|
142
151
|
flow_name: str,
|
143
|
-
|
152
|
+
run_name: str,
|
144
153
|
logger: Callable[[AgentResponse[Any]], Awaitable[None]] | None = None,
|
145
154
|
) -> None:
|
146
155
|
self.flow_name = flow_name
|
147
|
-
self.
|
156
|
+
self.run_name = run_name
|
148
157
|
self.logger = logger
|
149
158
|
|
150
159
|
async def __call__[R: BaseModel](
|
@@ -180,7 +189,7 @@ class Agent:
|
|
180
189
|
end_time = datetime.now()
|
181
190
|
agent_response = AgentResponse(
|
182
191
|
response=parsed_response,
|
183
|
-
|
192
|
+
run_name=self.run_name,
|
184
193
|
flow_name=self.flow_name,
|
185
194
|
task_name=task_name,
|
186
195
|
model=model,
|
@@ -14,10 +14,10 @@ from typing import (
|
|
14
14
|
|
15
15
|
from pydantic import BaseModel, ConfigDict, field_validator
|
16
16
|
|
17
|
-
from goose.agent import UserMessage
|
17
|
+
from goose.agent import Agent, AssistantMessage, LLMMessage, SystemMessage, UserMessage
|
18
18
|
from goose.errors import Honk
|
19
19
|
|
20
|
-
|
20
|
+
SerializedFlowRun = NewType("SerializedFlowRun", str)
|
21
21
|
|
22
22
|
|
23
23
|
class Result(BaseModel):
|
@@ -28,8 +28,9 @@ class GooseResponse[R: Result](BaseModel):
|
|
28
28
|
result: R
|
29
29
|
|
30
30
|
|
31
|
-
class
|
31
|
+
class Conversation[R: Result](BaseModel):
|
32
32
|
messages: list[UserMessage | GooseResponse[R]]
|
33
|
+
context: SystemMessage | None = None
|
33
34
|
|
34
35
|
@field_validator("messages")
|
35
36
|
def alternates_starting_with_result(
|
@@ -56,28 +57,44 @@ class ConversationState[R: Result](BaseModel):
|
|
56
57
|
def awaiting_response(self) -> bool:
|
57
58
|
return len(self.messages) % 2 == 0
|
58
59
|
|
60
|
+
def render(self) -> list[LLMMessage]:
|
61
|
+
messages: list[LLMMessage] = []
|
62
|
+
if self.context is not None:
|
63
|
+
messages.append(self.context.render())
|
64
|
+
|
65
|
+
for message in self.messages:
|
66
|
+
if isinstance(message, UserMessage):
|
67
|
+
messages.append(message.render())
|
68
|
+
else:
|
69
|
+
messages.append(
|
70
|
+
AssistantMessage(text=message.result.model_dump_json()).render()
|
71
|
+
)
|
72
|
+
|
73
|
+
return messages
|
74
|
+
|
59
75
|
|
60
76
|
class IAdapter[ResultT: Result](Protocol):
|
61
|
-
async def __call__(
|
62
|
-
self, *, conversation_state: ConversationState[ResultT]
|
63
|
-
) -> ResultT: ...
|
77
|
+
async def __call__(self, *, conversation: Conversation[ResultT]) -> ResultT: ...
|
64
78
|
|
65
79
|
|
66
80
|
class NodeState[ResultT: Result](BaseModel):
|
67
81
|
task_name: str
|
68
82
|
index: int
|
69
|
-
|
83
|
+
conversation: Conversation[ResultT]
|
70
84
|
last_input_hash: int
|
71
|
-
pinned: bool
|
72
85
|
|
73
86
|
@property
|
74
87
|
def result(self) -> ResultT:
|
75
|
-
last_message = self.
|
88
|
+
last_message = self.conversation.messages[-1]
|
76
89
|
if isinstance(last_message, GooseResponse):
|
77
90
|
return last_message.result
|
78
91
|
else:
|
79
92
|
raise Honk("Node awaiting response, has no result")
|
80
93
|
|
94
|
+
def set_context(self, *, context: SystemMessage) -> Self:
|
95
|
+
self.conversation.context = context
|
96
|
+
return self
|
97
|
+
|
81
98
|
def add_result(
|
82
99
|
self,
|
83
100
|
*,
|
@@ -86,33 +103,37 @@ class NodeState[ResultT: Result](BaseModel):
|
|
86
103
|
overwrite: bool = False,
|
87
104
|
) -> Self:
|
88
105
|
if overwrite:
|
89
|
-
if len(self.
|
90
|
-
self.
|
106
|
+
if len(self.conversation.messages) == 0:
|
107
|
+
self.conversation.messages.append(GooseResponse(result=result))
|
91
108
|
else:
|
92
|
-
self.
|
109
|
+
self.conversation.messages[-1] = GooseResponse(result=result)
|
93
110
|
else:
|
94
|
-
self.
|
111
|
+
self.conversation.messages.append(GooseResponse(result=result))
|
95
112
|
if new_input_hash is not None:
|
96
113
|
self.last_input_hash = new_input_hash
|
97
114
|
return self
|
98
115
|
|
99
116
|
def add_user_message(self, *, message: UserMessage) -> Self:
|
100
|
-
self.
|
101
|
-
return self
|
102
|
-
|
103
|
-
def pin(self) -> Self:
|
104
|
-
self.pinned = True
|
105
|
-
return self
|
106
|
-
|
107
|
-
def unpin(self) -> Self:
|
108
|
-
self.pinned = False
|
117
|
+
self.conversation.messages.append(message)
|
109
118
|
return self
|
110
119
|
|
111
120
|
|
112
|
-
class
|
121
|
+
class FlowRun:
|
113
122
|
def __init__(self) -> None:
|
114
123
|
self._node_states: dict[tuple[str, int], str] = {}
|
115
124
|
self._last_requested_indices: dict[str, int] = {}
|
125
|
+
self._name = ""
|
126
|
+
self._agent: Agent | None = None
|
127
|
+
|
128
|
+
@property
|
129
|
+
def name(self) -> str:
|
130
|
+
return self._name
|
131
|
+
|
132
|
+
@property
|
133
|
+
def agent(self) -> Agent:
|
134
|
+
if self._agent is None:
|
135
|
+
raise Honk("Agent is only accessible once a run is started")
|
136
|
+
return self._agent
|
116
137
|
|
117
138
|
def add(self, node_state: NodeState[Any], /) -> None:
|
118
139
|
key = (node_state.task_name, node_state.index)
|
@@ -144,19 +165,22 @@ class FlowState:
|
|
144
165
|
return NodeState[task.result_type](
|
145
166
|
task_name=task.name,
|
146
167
|
index=index or 0,
|
147
|
-
|
168
|
+
conversation=Conversation[task.result_type](messages=[]),
|
148
169
|
last_input_hash=0,
|
149
|
-
pinned=False,
|
150
170
|
)
|
151
171
|
|
152
|
-
|
153
|
-
def run(self) -> Iterator[Self]:
|
172
|
+
def start(self, *, name: str) -> None:
|
154
173
|
self._last_requested_indices = {}
|
155
|
-
|
174
|
+
self._name = name
|
175
|
+
self._agent = Agent(flow_name=self.name, run_name=name)
|
176
|
+
|
177
|
+
def end(self) -> None:
|
156
178
|
self._last_requested_indices = {}
|
179
|
+
self._name = ""
|
180
|
+
self._agent = None
|
157
181
|
|
158
|
-
def dump(self) ->
|
159
|
-
return
|
182
|
+
def dump(self) -> SerializedFlowRun:
|
183
|
+
return SerializedFlowRun(
|
160
184
|
json.dumps(
|
161
185
|
{
|
162
186
|
":".join([task_name, str(index)]): value
|
@@ -166,20 +190,20 @@ class FlowState:
|
|
166
190
|
)
|
167
191
|
|
168
192
|
@classmethod
|
169
|
-
def load(cls,
|
170
|
-
|
171
|
-
raw_node_states = json.loads(
|
193
|
+
def load(cls, run: SerializedFlowRun, /) -> Self:
|
194
|
+
flow_run = cls()
|
195
|
+
raw_node_states = json.loads(run)
|
172
196
|
new_node_states: dict[tuple[str, int], str] = {}
|
173
197
|
for key, node_state in raw_node_states.items():
|
174
198
|
task_name, index = tuple(key.split(":"))
|
175
199
|
new_node_states[(task_name, int(index))] = node_state
|
176
200
|
|
177
|
-
|
178
|
-
return
|
201
|
+
flow_run._node_states = new_node_states
|
202
|
+
return flow_run
|
179
203
|
|
180
204
|
|
181
|
-
|
182
|
-
"
|
205
|
+
_current_flow_run: ContextVar[FlowRun | None] = ContextVar(
|
206
|
+
"current_flow_run", default=None
|
183
207
|
)
|
184
208
|
|
185
209
|
|
@@ -195,25 +219,28 @@ class Flow[**P]:
|
|
195
219
|
return self._name or self._fn.__name__
|
196
220
|
|
197
221
|
@property
|
198
|
-
def
|
199
|
-
|
200
|
-
if
|
201
|
-
raise Honk("No current flow
|
202
|
-
return
|
222
|
+
def current_run(self) -> FlowRun:
|
223
|
+
run = _current_flow_run.get()
|
224
|
+
if run is None:
|
225
|
+
raise Honk("No current flow run")
|
226
|
+
return run
|
203
227
|
|
204
228
|
@contextmanager
|
205
|
-
def
|
206
|
-
if
|
207
|
-
|
229
|
+
def start_run(self, *, name: str, run: FlowRun | None = None) -> Iterator[FlowRun]:
|
230
|
+
if run is None:
|
231
|
+
run = FlowRun()
|
232
|
+
|
233
|
+
old_run = _current_flow_run.get()
|
234
|
+
_current_flow_run.set(run)
|
235
|
+
|
236
|
+
run.start(name=name)
|
237
|
+
yield run
|
238
|
+
run.end()
|
208
239
|
|
209
|
-
|
210
|
-
_current_flow_state.set(state)
|
211
|
-
yield state
|
212
|
-
_current_flow_state.set(old_state)
|
240
|
+
_current_flow_run.set(old_run)
|
213
241
|
|
214
242
|
async def generate(self, *args: P.args, **kwargs: P.kwargs) -> None:
|
215
|
-
|
216
|
-
await self._fn(*args, **kwargs)
|
243
|
+
await self._fn(*args, **kwargs)
|
217
244
|
|
218
245
|
|
219
246
|
class Task[**P, R: Result]:
|
@@ -252,31 +279,39 @@ class Task[**P, R: Result]:
|
|
252
279
|
state.add_result(result=result, new_input_hash=input_hash, overwrite=True)
|
253
280
|
return result
|
254
281
|
else:
|
255
|
-
if not isinstance(state.
|
282
|
+
if not isinstance(state.conversation.messages[-1], GooseResponse):
|
256
283
|
raise Honk(
|
257
284
|
"Conversation must alternate between User and Result messages"
|
258
285
|
)
|
259
286
|
return state.result
|
260
287
|
|
261
|
-
async def
|
262
|
-
self,
|
288
|
+
async def jam(
|
289
|
+
self,
|
290
|
+
*,
|
291
|
+
user_message: UserMessage,
|
292
|
+
context: SystemMessage | None = None,
|
293
|
+
index: int = 0,
|
263
294
|
) -> R:
|
264
|
-
|
295
|
+
flow_run = self.__get_current_flow_run()
|
296
|
+
node_state = flow_run.get(task=self, index=index)
|
265
297
|
if self._adapter is None:
|
266
298
|
raise Honk("No adapter provided for Task")
|
267
299
|
|
300
|
+
if context is not None:
|
301
|
+
node_state.set_context(context=context)
|
268
302
|
node_state.add_user_message(message=user_message)
|
269
|
-
|
303
|
+
|
304
|
+
result = await self._adapter(conversation=node_state.conversation)
|
270
305
|
node_state.add_result(result=result)
|
271
|
-
|
306
|
+
flow_run.add(node_state)
|
272
307
|
|
273
308
|
return result
|
274
309
|
|
275
310
|
async def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R:
|
276
|
-
|
277
|
-
node_state =
|
311
|
+
flow_run = self.__get_current_flow_run()
|
312
|
+
node_state = flow_run.get_next(task=self)
|
278
313
|
result = await self.generate(node_state, *args, **kwargs)
|
279
|
-
|
314
|
+
flow_run.add(node_state)
|
280
315
|
return result
|
281
316
|
|
282
317
|
def __hash_input(self, *args: P.args, **kwargs: P.kwargs) -> int:
|
@@ -286,11 +321,11 @@ class Task[**P, R: Result]:
|
|
286
321
|
except TypeError:
|
287
322
|
raise Honk(f"Unhashable argument to task {self.name}: {args} {kwargs}")
|
288
323
|
|
289
|
-
def
|
290
|
-
|
291
|
-
if
|
292
|
-
raise Honk("No current flow
|
293
|
-
return
|
324
|
+
def __get_current_flow_run(self) -> FlowRun:
|
325
|
+
run = _current_flow_run.get()
|
326
|
+
if run is None:
|
327
|
+
raise Honk("No current flow run")
|
328
|
+
return run
|
294
329
|
|
295
330
|
|
296
331
|
@overload
|
File without changes
|
goose_py-0.2.1/goose/__init__.py
DELETED
File without changes
|
File without changes
|
File without changes
|