goose-py 0.2.2__tar.gz → 0.3.1__tar.gz
This diff shows the changes between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the differences between package versions as they appear in their respective public registries.
- {goose_py-0.2.2 → goose_py-0.3.1}/PKG-INFO +1 -1
- {goose_py-0.2.2 → goose_py-0.3.1}/goose/agent.py +7 -4
- {goose_py-0.2.2 → goose_py-0.3.1}/goose/flow.py +111 -65
- goose_py-0.3.1/goose/py.typed +0 -0
- {goose_py-0.2.2 → goose_py-0.3.1}/pyproject.toml +1 -1
- goose_py-0.2.2/goose/__init__.py +0 -3
- {goose_py-0.2.2 → goose_py-0.3.1}/README.md +0 -0
- /goose_py-0.2.2/goose/py.typed → /goose_py-0.3.1/goose/__init__.py +0 -0
- {goose_py-0.2.2 → goose_py-0.3.1}/goose/errors.py +0 -0
@@ -88,10 +88,13 @@ class AssistantMessage(BaseModel):
|
|
88
88
|
|
89
89
|
|
90
90
|
class SystemMessage(BaseModel):
|
91
|
-
|
91
|
+
parts: list[TextMessagePart | MediaMessagePart]
|
92
92
|
|
93
93
|
def render(self) -> LLMMessage:
|
94
|
-
return {
|
94
|
+
return {
|
95
|
+
"role": "system",
|
96
|
+
"content": [part.render() for part in self.parts],
|
97
|
+
}
|
95
98
|
|
96
99
|
|
97
100
|
class AgentResponse[R: BaseModel](BaseModel):
|
@@ -109,7 +112,7 @@ class AgentResponse[R: BaseModel](BaseModel):
|
|
109
112
|
}
|
110
113
|
|
111
114
|
response: R
|
112
|
-
|
115
|
+
run_id: str
|
113
116
|
flow_name: str
|
114
117
|
task_name: str
|
115
118
|
model: GeminiModel
|
@@ -186,7 +189,7 @@ class Agent:
|
|
186
189
|
end_time = datetime.now()
|
187
190
|
agent_response = AgentResponse(
|
188
191
|
response=parsed_response,
|
189
|
-
|
192
|
+
run_id=self.run_id,
|
190
193
|
flow_name=self.flow_name,
|
191
194
|
task_name=task_name,
|
192
195
|
model=model,
|
@@ -14,10 +14,10 @@ from typing import (
|
|
14
14
|
|
15
15
|
from pydantic import BaseModel, ConfigDict, field_validator
|
16
16
|
|
17
|
-
from goose.agent import UserMessage
|
17
|
+
from goose.agent import Agent, AssistantMessage, LLMMessage, SystemMessage, UserMessage
|
18
18
|
from goose.errors import Honk
|
19
19
|
|
20
|
-
|
20
|
+
SerializedFlowRun = NewType("SerializedFlowRun", str)
|
21
21
|
|
22
22
|
|
23
23
|
class Result(BaseModel):
|
@@ -28,8 +28,9 @@ class GooseResponse[R: Result](BaseModel):
|
|
28
28
|
result: R
|
29
29
|
|
30
30
|
|
31
|
-
class
|
31
|
+
class Conversation[R: Result](BaseModel):
|
32
32
|
messages: list[UserMessage | GooseResponse[R]]
|
33
|
+
context: SystemMessage | None = None
|
33
34
|
|
34
35
|
@field_validator("messages")
|
35
36
|
def alternates_starting_with_result(
|
@@ -56,28 +57,44 @@ class ConversationState[R: Result](BaseModel):
|
|
56
57
|
def awaiting_response(self) -> bool:
|
57
58
|
return len(self.messages) % 2 == 0
|
58
59
|
|
60
|
+
def render(self) -> list[LLMMessage]:
|
61
|
+
messages: list[LLMMessage] = []
|
62
|
+
if self.context is not None:
|
63
|
+
messages.append(self.context.render())
|
64
|
+
|
65
|
+
for message in self.messages:
|
66
|
+
if isinstance(message, UserMessage):
|
67
|
+
messages.append(message.render())
|
68
|
+
else:
|
69
|
+
messages.append(
|
70
|
+
AssistantMessage(text=message.result.model_dump_json()).render()
|
71
|
+
)
|
72
|
+
|
73
|
+
return messages
|
74
|
+
|
59
75
|
|
60
76
|
class IAdapter[ResultT: Result](Protocol):
|
61
|
-
async def __call__(
|
62
|
-
self, *, conversation_state: ConversationState[ResultT]
|
63
|
-
) -> ResultT: ...
|
77
|
+
async def __call__(self, *, conversation: Conversation[ResultT]) -> ResultT: ...
|
64
78
|
|
65
79
|
|
66
80
|
class NodeState[ResultT: Result](BaseModel):
|
67
81
|
task_name: str
|
68
82
|
index: int
|
69
|
-
|
83
|
+
conversation: Conversation[ResultT]
|
70
84
|
last_input_hash: int
|
71
|
-
pinned: bool
|
72
85
|
|
73
86
|
@property
|
74
87
|
def result(self) -> ResultT:
|
75
|
-
last_message = self.
|
88
|
+
last_message = self.conversation.messages[-1]
|
76
89
|
if isinstance(last_message, GooseResponse):
|
77
90
|
return last_message.result
|
78
91
|
else:
|
79
92
|
raise Honk("Node awaiting response, has no result")
|
80
93
|
|
94
|
+
def set_context(self, *, context: SystemMessage) -> Self:
|
95
|
+
self.conversation.context = context
|
96
|
+
return self
|
97
|
+
|
81
98
|
def add_result(
|
82
99
|
self,
|
83
100
|
*,
|
@@ -86,33 +103,42 @@ class NodeState[ResultT: Result](BaseModel):
|
|
86
103
|
overwrite: bool = False,
|
87
104
|
) -> Self:
|
88
105
|
if overwrite:
|
89
|
-
if len(self.
|
90
|
-
self.
|
106
|
+
if len(self.conversation.messages) == 0:
|
107
|
+
self.conversation.messages.append(GooseResponse(result=result))
|
91
108
|
else:
|
92
|
-
self.
|
109
|
+
self.conversation.messages[-1] = GooseResponse(result=result)
|
93
110
|
else:
|
94
|
-
self.
|
111
|
+
self.conversation.messages.append(GooseResponse(result=result))
|
95
112
|
if new_input_hash is not None:
|
96
113
|
self.last_input_hash = new_input_hash
|
97
114
|
return self
|
98
115
|
|
99
116
|
def add_user_message(self, *, message: UserMessage) -> Self:
|
100
|
-
self.
|
101
|
-
return self
|
102
|
-
|
103
|
-
def pin(self) -> Self:
|
104
|
-
self.pinned = True
|
105
|
-
return self
|
106
|
-
|
107
|
-
def unpin(self) -> Self:
|
108
|
-
self.pinned = False
|
117
|
+
self.conversation.messages.append(message)
|
109
118
|
return self
|
110
119
|
|
111
120
|
|
112
|
-
class
|
121
|
+
class FlowRun:
|
113
122
|
def __init__(self) -> None:
|
114
123
|
self._node_states: dict[tuple[str, int], str] = {}
|
115
124
|
self._last_requested_indices: dict[str, int] = {}
|
125
|
+
self._flow_name = ""
|
126
|
+
self._id = ""
|
127
|
+
self._agent: Agent | None = None
|
128
|
+
|
129
|
+
@property
|
130
|
+
def flow_name(self) -> str:
|
131
|
+
return self._flow_name
|
132
|
+
|
133
|
+
@property
|
134
|
+
def id(self) -> str:
|
135
|
+
return self._id
|
136
|
+
|
137
|
+
@property
|
138
|
+
def agent(self) -> Agent:
|
139
|
+
if self._agent is None:
|
140
|
+
raise Honk("Agent is only accessible once a run is started")
|
141
|
+
return self._agent
|
116
142
|
|
117
143
|
def add(self, node_state: NodeState[Any], /) -> None:
|
118
144
|
key = (node_state.task_name, node_state.index)
|
@@ -144,19 +170,24 @@ class FlowState:
|
|
144
170
|
return NodeState[task.result_type](
|
145
171
|
task_name=task.name,
|
146
172
|
index=index or 0,
|
147
|
-
|
173
|
+
conversation=Conversation[task.result_type](messages=[]),
|
148
174
|
last_input_hash=0,
|
149
|
-
pinned=False,
|
150
175
|
)
|
151
176
|
|
152
|
-
|
153
|
-
def run(self) -> Iterator[Self]:
|
177
|
+
def start(self, *, flow_name: str, run_id: str) -> None:
|
154
178
|
self._last_requested_indices = {}
|
155
|
-
|
179
|
+
self._flow_name = flow_name
|
180
|
+
self._id = run_id
|
181
|
+
self._agent = Agent(flow_name=self.flow_name, run_id=self.id)
|
182
|
+
|
183
|
+
def end(self) -> None:
|
156
184
|
self._last_requested_indices = {}
|
185
|
+
self._flow_name = ""
|
186
|
+
self._id = ""
|
187
|
+
self._agent = None
|
157
188
|
|
158
|
-
def dump(self) ->
|
159
|
-
return
|
189
|
+
def dump(self) -> SerializedFlowRun:
|
190
|
+
return SerializedFlowRun(
|
160
191
|
json.dumps(
|
161
192
|
{
|
162
193
|
":".join([task_name, str(index)]): value
|
@@ -166,20 +197,20 @@ class FlowState:
|
|
166
197
|
)
|
167
198
|
|
168
199
|
@classmethod
|
169
|
-
def load(cls,
|
170
|
-
|
171
|
-
raw_node_states = json.loads(
|
200
|
+
def load(cls, run: SerializedFlowRun, /) -> Self:
|
201
|
+
flow_run = cls()
|
202
|
+
raw_node_states = json.loads(run)
|
172
203
|
new_node_states: dict[tuple[str, int], str] = {}
|
173
204
|
for key, node_state in raw_node_states.items():
|
174
205
|
task_name, index = tuple(key.split(":"))
|
175
206
|
new_node_states[(task_name, int(index))] = node_state
|
176
207
|
|
177
|
-
|
178
|
-
return
|
208
|
+
flow_run._node_states = new_node_states
|
209
|
+
return flow_run
|
179
210
|
|
180
211
|
|
181
|
-
|
182
|
-
"
|
212
|
+
_current_flow_run: ContextVar[FlowRun | None] = ContextVar(
|
213
|
+
"current_flow_run", default=None
|
183
214
|
)
|
184
215
|
|
185
216
|
|
@@ -195,25 +226,32 @@ class Flow[**P]:
|
|
195
226
|
return self._name or self._fn.__name__
|
196
227
|
|
197
228
|
@property
|
198
|
-
def
|
199
|
-
|
200
|
-
if
|
201
|
-
raise Honk("No current flow
|
202
|
-
return
|
229
|
+
def current_run(self) -> FlowRun:
|
230
|
+
run = _current_flow_run.get()
|
231
|
+
if run is None:
|
232
|
+
raise Honk("No current flow run")
|
233
|
+
return run
|
203
234
|
|
204
235
|
@contextmanager
|
205
|
-
def
|
206
|
-
|
207
|
-
|
236
|
+
def start_run(
|
237
|
+
self, *, run_id: str, preload: FlowRun | None = None
|
238
|
+
) -> Iterator[FlowRun]:
|
239
|
+
if preload is None:
|
240
|
+
run = FlowRun()
|
241
|
+
else:
|
242
|
+
run = preload
|
243
|
+
|
244
|
+
old_run = _current_flow_run.get()
|
245
|
+
_current_flow_run.set(run)
|
246
|
+
|
247
|
+
run.start(flow_name=self.name, run_id=run_id)
|
248
|
+
yield run
|
249
|
+
run.end()
|
208
250
|
|
209
|
-
|
210
|
-
_current_flow_state.set(state)
|
211
|
-
yield state
|
212
|
-
_current_flow_state.set(old_state)
|
251
|
+
_current_flow_run.set(old_run)
|
213
252
|
|
214
253
|
async def generate(self, *args: P.args, **kwargs: P.kwargs) -> None:
|
215
|
-
|
216
|
-
await self._fn(*args, **kwargs)
|
254
|
+
await self._fn(*args, **kwargs)
|
217
255
|
|
218
256
|
|
219
257
|
class Task[**P, R: Result]:
|
@@ -252,31 +290,39 @@ class Task[**P, R: Result]:
|
|
252
290
|
state.add_result(result=result, new_input_hash=input_hash, overwrite=True)
|
253
291
|
return result
|
254
292
|
else:
|
255
|
-
if not isinstance(state.
|
293
|
+
if not isinstance(state.conversation.messages[-1], GooseResponse):
|
256
294
|
raise Honk(
|
257
295
|
"Conversation must alternate between User and Result messages"
|
258
296
|
)
|
259
297
|
return state.result
|
260
298
|
|
261
|
-
async def
|
262
|
-
self,
|
299
|
+
async def jam(
|
300
|
+
self,
|
301
|
+
*,
|
302
|
+
user_message: UserMessage,
|
303
|
+
context: SystemMessage | None = None,
|
304
|
+
index: int = 0,
|
263
305
|
) -> R:
|
264
|
-
|
306
|
+
flow_run = self.__get_current_flow_run()
|
307
|
+
node_state = flow_run.get(task=self, index=index)
|
265
308
|
if self._adapter is None:
|
266
309
|
raise Honk("No adapter provided for Task")
|
267
310
|
|
311
|
+
if context is not None:
|
312
|
+
node_state.set_context(context=context)
|
268
313
|
node_state.add_user_message(message=user_message)
|
269
|
-
|
314
|
+
|
315
|
+
result = await self._adapter(conversation=node_state.conversation)
|
270
316
|
node_state.add_result(result=result)
|
271
|
-
|
317
|
+
flow_run.add(node_state)
|
272
318
|
|
273
319
|
return result
|
274
320
|
|
275
321
|
async def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R:
|
276
|
-
|
277
|
-
node_state =
|
322
|
+
flow_run = self.__get_current_flow_run()
|
323
|
+
node_state = flow_run.get_next(task=self)
|
278
324
|
result = await self.generate(node_state, *args, **kwargs)
|
279
|
-
|
325
|
+
flow_run.add(node_state)
|
280
326
|
return result
|
281
327
|
|
282
328
|
def __hash_input(self, *args: P.args, **kwargs: P.kwargs) -> int:
|
@@ -286,11 +332,11 @@ class Task[**P, R: Result]:
|
|
286
332
|
except TypeError:
|
287
333
|
raise Honk(f"Unhashable argument to task {self.name}: {args} {kwargs}")
|
288
334
|
|
289
|
-
def
|
290
|
-
|
291
|
-
if
|
292
|
-
raise Honk("No current flow
|
293
|
-
return
|
335
|
+
def __get_current_flow_run(self) -> FlowRun:
|
336
|
+
run = _current_flow_run.get()
|
337
|
+
if run is None:
|
338
|
+
raise Honk("No current flow run")
|
339
|
+
return run
|
294
340
|
|
295
341
|
|
296
342
|
@overload
|
File without changes
|
goose_py-0.2.2/goose/__init__.py
DELETED
File without changes
|
File without changes
|
File without changes
|