goose-py 0.1.4__tar.gz → 0.2.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {goose_py-0.1.4 → goose_py-0.2.0}/PKG-INFO +2 -2
- {goose_py-0.1.4 → goose_py-0.2.0}/README.md +1 -1
- goose_py-0.2.0/goose/__init__.py +3 -0
- goose_py-0.1.4/goose/types.py → goose_py-0.2.0/goose/agent.py +64 -1
- goose_py-0.2.0/goose/flow.py +334 -0
- {goose_py-0.1.4 → goose_py-0.2.0}/pyproject.toml +1 -1
- goose_py-0.1.4/goose/__init__.py +0 -0
- goose_py-0.1.4/goose/agent.py +0 -75
- goose_py-0.1.4/goose/conversation.py +0 -45
- goose_py-0.1.4/goose/core.py +0 -294
- {goose_py-0.1.4 → goose_py-0.2.0}/goose/errors.py +0 -0
- {goose_py-0.1.4 → goose_py-0.2.0}/goose/py.typed +0 -0
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: goose-py
|
3
|
-
Version: 0.1.4
|
3
|
+
Version: 0.2.0
|
4
4
|
Summary: A tool for AI workflows based on human-computer collaboration and structured output.
|
5
5
|
Home-page: https://github.com/chelle-ai/goose
|
6
6
|
Keywords: ai,yaml,configuration,llm
|
@@ -25,7 +25,7 @@ Project-URL: Documentation, https://github.com/chelle-ai/goose
|
|
25
25
|
Project-URL: Repository, https://github.com/chelle-ai/goose
|
26
26
|
Description-Content-Type: text/markdown
|
27
27
|
|
28
|
-
#
|
28
|
+
# Goose
|
29
29
|
|
30
30
|
Docs to come.
|
31
31
|
|
@@ -1,8 +1,10 @@
|
|
1
1
|
import base64
|
2
|
+
import logging
|
2
3
|
from datetime import datetime
|
3
4
|
from enum import StrEnum
|
4
|
-
from typing import ClassVar, Literal, NotRequired, TypedDict
|
5
|
+
from typing import Any, Callable, ClassVar, Literal, NotRequired, TypedDict
|
5
6
|
|
7
|
+
from litellm import acompletion
|
6
8
|
from pydantic import BaseModel, computed_field
|
7
9
|
|
8
10
|
|
@@ -131,3 +133,64 @@ class AgentResponse[R: BaseModel](BaseModel):
|
|
131
133
|
self.OUTPUT_CENTS_PER_MILLION_TOKENS[self.model] * self.output_tokens
|
132
134
|
)
|
133
135
|
return input_cost + output_cost
|
136
|
+
|
137
|
+
|
138
|
+
class Agent:
|
139
|
+
def __init__(
|
140
|
+
self,
|
141
|
+
*,
|
142
|
+
flow_name: str,
|
143
|
+
run_id: str,
|
144
|
+
logger: Callable[[AgentResponse[Any]], None] | None = None,
|
145
|
+
) -> None:
|
146
|
+
self.flow_name = flow_name
|
147
|
+
self.run_id = run_id
|
148
|
+
self.logger = logger or logging.info
|
149
|
+
|
150
|
+
async def __call__[R: BaseModel](
|
151
|
+
self,
|
152
|
+
*,
|
153
|
+
messages: list[UserMessage | AssistantMessage],
|
154
|
+
model: GeminiModel,
|
155
|
+
response_model: type[R],
|
156
|
+
task_name: str,
|
157
|
+
system: SystemMessage | None = None,
|
158
|
+
) -> R:
|
159
|
+
start_time = datetime.now()
|
160
|
+
rendered_messages = [message.render() for message in messages]
|
161
|
+
if system is not None:
|
162
|
+
rendered_messages.insert(0, system.render())
|
163
|
+
|
164
|
+
response = await acompletion(
|
165
|
+
model=model.value,
|
166
|
+
messages=rendered_messages,
|
167
|
+
response_format={
|
168
|
+
"type": "json_object",
|
169
|
+
"response_schema": response_model.model_json_schema(),
|
170
|
+
"enforce_validation": True,
|
171
|
+
},
|
172
|
+
)
|
173
|
+
|
174
|
+
if len(response.choices) == 0:
|
175
|
+
raise RuntimeError("No content returned from LLM call.")
|
176
|
+
|
177
|
+
parsed_response = response_model.model_validate_json(
|
178
|
+
response.choices[0].message.content
|
179
|
+
)
|
180
|
+
end_time = datetime.now()
|
181
|
+
agent_response = AgentResponse(
|
182
|
+
response=parsed_response,
|
183
|
+
id=self.run_id,
|
184
|
+
flow_name=self.flow_name,
|
185
|
+
task_name=task_name,
|
186
|
+
model=model,
|
187
|
+
system=system,
|
188
|
+
input_messages=messages,
|
189
|
+
input_tokens=response.usage.prompt_tokens,
|
190
|
+
output_tokens=response.usage.completion_tokens,
|
191
|
+
start_time=start_time,
|
192
|
+
end_time=end_time,
|
193
|
+
)
|
194
|
+
|
195
|
+
self.logger(agent_response)
|
196
|
+
return agent_response.response
|
@@ -0,0 +1,334 @@
|
|
1
|
+
import json
|
2
|
+
from contextlib import contextmanager
|
3
|
+
from contextvars import ContextVar
|
4
|
+
from typing import (
|
5
|
+
Any,
|
6
|
+
Awaitable,
|
7
|
+
Callable,
|
8
|
+
Iterator,
|
9
|
+
NewType,
|
10
|
+
Protocol,
|
11
|
+
Self,
|
12
|
+
overload,
|
13
|
+
)
|
14
|
+
|
15
|
+
from pydantic import BaseModel, ConfigDict, field_validator
|
16
|
+
|
17
|
+
from goose.agent import UserMessage
|
18
|
+
from goose.errors import Honk
|
19
|
+
|
20
|
+
SerializedFlowState = NewType("SerializedFlowState", str)
|
21
|
+
|
22
|
+
|
23
|
+
class Result(BaseModel):
|
24
|
+
model_config = ConfigDict(frozen=True)
|
25
|
+
|
26
|
+
|
27
|
+
class GooseResponse[R: Result](BaseModel):
|
28
|
+
result: R
|
29
|
+
|
30
|
+
|
31
|
+
class ConversationState[R: Result](BaseModel):
|
32
|
+
messages: list[UserMessage | GooseResponse[R]]
|
33
|
+
|
34
|
+
@field_validator("messages")
|
35
|
+
def alternates_starting_with_result(
|
36
|
+
cls, messages: list[UserMessage | GooseResponse[R]]
|
37
|
+
) -> list[UserMessage | GooseResponse[R]]:
|
38
|
+
if len(messages) == 0:
|
39
|
+
return messages
|
40
|
+
elif isinstance(messages[0], UserMessage):
|
41
|
+
raise Honk(
|
42
|
+
"User cannot start a conversation on a Task, must begin with a Result"
|
43
|
+
)
|
44
|
+
|
45
|
+
last_message_type: type[UserMessage | GooseResponse[R]] = type(messages[0])
|
46
|
+
for message in messages[1:]:
|
47
|
+
if isinstance(message, last_message_type):
|
48
|
+
raise Honk(
|
49
|
+
"Conversation must alternate between User and Result messages"
|
50
|
+
)
|
51
|
+
last_message_type = type(message)
|
52
|
+
|
53
|
+
return messages
|
54
|
+
|
55
|
+
@property
|
56
|
+
def awaiting_response(self) -> bool:
|
57
|
+
return len(self.messages) % 2 == 0
|
58
|
+
|
59
|
+
|
60
|
+
class IAdapter[ResultT: Result](Protocol):
|
61
|
+
async def __call__(
|
62
|
+
self, *, conversation_state: ConversationState[ResultT]
|
63
|
+
) -> ResultT: ...
|
64
|
+
|
65
|
+
|
66
|
+
class NodeState[ResultT: Result](BaseModel):
|
67
|
+
task_name: str
|
68
|
+
index: int
|
69
|
+
conversation_state: ConversationState[ResultT]
|
70
|
+
last_input_hash: int
|
71
|
+
pinned: bool
|
72
|
+
|
73
|
+
@property
|
74
|
+
def result(self) -> ResultT:
|
75
|
+
last_message = self.conversation_state.messages[-1]
|
76
|
+
if isinstance(last_message, GooseResponse):
|
77
|
+
return last_message.result
|
78
|
+
else:
|
79
|
+
raise Honk("Node awaiting response, has no result")
|
80
|
+
|
81
|
+
def add_result(
|
82
|
+
self,
|
83
|
+
*,
|
84
|
+
result: ResultT,
|
85
|
+
new_input_hash: int | None = None,
|
86
|
+
overwrite: bool = False,
|
87
|
+
) -> Self:
|
88
|
+
if overwrite:
|
89
|
+
if len(self.conversation_state.messages) == 0:
|
90
|
+
self.conversation_state.messages.append(GooseResponse(result=result))
|
91
|
+
else:
|
92
|
+
self.conversation_state.messages[-1] = GooseResponse(result=result)
|
93
|
+
else:
|
94
|
+
self.conversation_state.messages.append(GooseResponse(result=result))
|
95
|
+
if new_input_hash is not None:
|
96
|
+
self.last_input_hash = new_input_hash
|
97
|
+
return self
|
98
|
+
|
99
|
+
def add_user_message(self, *, message: UserMessage) -> Self:
|
100
|
+
self.conversation_state.messages.append(message)
|
101
|
+
return self
|
102
|
+
|
103
|
+
def pin(self) -> Self:
|
104
|
+
self.pinned = True
|
105
|
+
return self
|
106
|
+
|
107
|
+
def unpin(self) -> Self:
|
108
|
+
self.pinned = False
|
109
|
+
return self
|
110
|
+
|
111
|
+
|
112
|
+
class FlowState:
|
113
|
+
def __init__(self) -> None:
|
114
|
+
self._node_states: dict[tuple[str, int], str] = {}
|
115
|
+
self._last_requested_indices: dict[str, int] = {}
|
116
|
+
|
117
|
+
def add(self, node_state: NodeState[Any], /) -> None:
|
118
|
+
key = (node_state.task_name, node_state.index)
|
119
|
+
self._node_states[key] = node_state.model_dump_json()
|
120
|
+
|
121
|
+
def get_next[R: Result](self, *, task: "Task[Any, R]") -> NodeState[R]:
|
122
|
+
if task.name not in self._last_requested_indices:
|
123
|
+
self._last_requested_indices[task.name] = 0
|
124
|
+
else:
|
125
|
+
self._last_requested_indices[task.name] += 1
|
126
|
+
|
127
|
+
return self.get(task=task, index=self._last_requested_indices[task.name])
|
128
|
+
|
129
|
+
def get_all[R: Result](self, *, task: "Task[Any, R]") -> list[NodeState[R]]:
|
130
|
+
matching_nodes: list[NodeState[R]] = []
|
131
|
+
for key, node_state in self._node_states.items():
|
132
|
+
if key[0] == task.name:
|
133
|
+
matching_nodes.append(
|
134
|
+
NodeState[task.result_type].model_validate_json(node_state)
|
135
|
+
)
|
136
|
+
return matching_nodes
|
137
|
+
|
138
|
+
def get[R: Result](self, *, task: "Task[Any, R]", index: int = 0) -> NodeState[R]:
|
139
|
+
if (
|
140
|
+
existing_node_state := self._node_states.get((task.name, index))
|
141
|
+
) is not None:
|
142
|
+
return NodeState[task.result_type].model_validate_json(existing_node_state)
|
143
|
+
else:
|
144
|
+
return NodeState[task.result_type](
|
145
|
+
task_name=task.name,
|
146
|
+
index=index or 0,
|
147
|
+
conversation_state=ConversationState[task.result_type](messages=[]),
|
148
|
+
last_input_hash=0,
|
149
|
+
pinned=False,
|
150
|
+
)
|
151
|
+
|
152
|
+
@contextmanager
|
153
|
+
def run(self) -> Iterator[Self]:
|
154
|
+
self._last_requested_indices = {}
|
155
|
+
yield self
|
156
|
+
self._last_requested_indices = {}
|
157
|
+
|
158
|
+
def dump(self) -> SerializedFlowState:
|
159
|
+
return SerializedFlowState(
|
160
|
+
json.dumps(
|
161
|
+
{
|
162
|
+
":".join([task_name, str(index)]): value
|
163
|
+
for (task_name, index), value in self._node_states.items()
|
164
|
+
}
|
165
|
+
)
|
166
|
+
)
|
167
|
+
|
168
|
+
@classmethod
|
169
|
+
def load(cls, state: SerializedFlowState) -> Self:
|
170
|
+
flow_state = cls()
|
171
|
+
raw_node_states = json.loads(state)
|
172
|
+
new_node_states: dict[tuple[str, int], str] = {}
|
173
|
+
for key, node_state in raw_node_states.items():
|
174
|
+
task_name, index = tuple(key.split(":"))
|
175
|
+
new_node_states[(task_name, int(index))] = node_state
|
176
|
+
|
177
|
+
flow_state._node_states = new_node_states
|
178
|
+
return flow_state
|
179
|
+
|
180
|
+
|
181
|
+
# FlowState of the flow executing in the current context; ``None`` when no
# ``Flow.run`` block is active. Read by ``Flow.state`` and by ``Task.__call__``.
_current_flow_state: ContextVar[FlowState | None]
_current_flow_state = ContextVar("current_flow_state", default=None)
|
184
|
+
|
185
|
+
|
186
|
+
class Flow[**P]:
|
187
|
+
def __init__(
|
188
|
+
self, fn: Callable[P, Awaitable[None]], /, *, name: str | None = None
|
189
|
+
) -> None:
|
190
|
+
self._fn = fn
|
191
|
+
self._name = name
|
192
|
+
|
193
|
+
@property
|
194
|
+
def name(self) -> str:
|
195
|
+
return self._name or self._fn.__name__
|
196
|
+
|
197
|
+
@property
|
198
|
+
def state(self) -> FlowState:
|
199
|
+
state = _current_flow_state.get()
|
200
|
+
if state is None:
|
201
|
+
raise Honk("No current flow state")
|
202
|
+
return state
|
203
|
+
|
204
|
+
@contextmanager
|
205
|
+
def run(self, *, state: FlowState | None = None) -> Iterator[FlowState]:
|
206
|
+
if state is None:
|
207
|
+
state = FlowState()
|
208
|
+
|
209
|
+
old_state = _current_flow_state.get()
|
210
|
+
_current_flow_state.set(state)
|
211
|
+
yield state
|
212
|
+
_current_flow_state.set(old_state)
|
213
|
+
|
214
|
+
async def generate(self, *args: P.args, **kwargs: P.kwargs) -> None:
|
215
|
+
with self.state.run():
|
216
|
+
await self._fn(*args, **kwargs)
|
217
|
+
|
218
|
+
|
219
|
+
class Task[**P, R: Result]:
|
220
|
+
def __init__(
|
221
|
+
self,
|
222
|
+
generator: Callable[P, Awaitable[R]],
|
223
|
+
/,
|
224
|
+
*,
|
225
|
+
retries: int = 0,
|
226
|
+
) -> None:
|
227
|
+
self._generator = generator
|
228
|
+
self._adapter: IAdapter[R] | None = None
|
229
|
+
self._retries = retries
|
230
|
+
|
231
|
+
@property
|
232
|
+
def result_type(self) -> type[R]:
|
233
|
+
result_type = self._generator.__annotations__.get("return")
|
234
|
+
if result_type is None:
|
235
|
+
raise Honk(f"Task {self.name} has no return type annotation")
|
236
|
+
return result_type
|
237
|
+
|
238
|
+
@property
|
239
|
+
def name(self) -> str:
|
240
|
+
return self._generator.__name__
|
241
|
+
|
242
|
+
def adapter(self, adapter: IAdapter[R]) -> Self:
|
243
|
+
self._adapter = adapter
|
244
|
+
return self
|
245
|
+
|
246
|
+
async def generate(
|
247
|
+
self, state: NodeState[R], *args: P.args, **kwargs: P.kwargs
|
248
|
+
) -> R:
|
249
|
+
input_hash = self.__hash_input(*args, **kwargs)
|
250
|
+
if input_hash != state.last_input_hash:
|
251
|
+
result = await self._generator(*args, **kwargs)
|
252
|
+
state.add_result(result=result, new_input_hash=input_hash, overwrite=True)
|
253
|
+
return result
|
254
|
+
else:
|
255
|
+
if not isinstance(state.conversation_state.messages[-1], GooseResponse):
|
256
|
+
raise Honk(
|
257
|
+
"Conversation must alternate between User and Result messages"
|
258
|
+
)
|
259
|
+
return state.result
|
260
|
+
|
261
|
+
async def adapt(
|
262
|
+
self, *, flow_state: FlowState, user_message: UserMessage, index: int = 0
|
263
|
+
) -> R:
|
264
|
+
node_state = flow_state.get(task=self, index=index)
|
265
|
+
if self._adapter is None:
|
266
|
+
raise Honk("No adapter provided for Task")
|
267
|
+
|
268
|
+
node_state.add_user_message(message=user_message)
|
269
|
+
result = await self._adapter(conversation_state=node_state.conversation_state)
|
270
|
+
node_state.add_result(result=result)
|
271
|
+
flow_state.add(node_state)
|
272
|
+
|
273
|
+
return result
|
274
|
+
|
275
|
+
async def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R:
|
276
|
+
flow_state = self.__get_current_flow_state()
|
277
|
+
node_state = flow_state.get_next(task=self)
|
278
|
+
result = await self.generate(node_state, *args, **kwargs)
|
279
|
+
flow_state.add(node_state)
|
280
|
+
return result
|
281
|
+
|
282
|
+
def __hash_input(self, *args: P.args, **kwargs: P.kwargs) -> int:
|
283
|
+
try:
|
284
|
+
to_hash = str(tuple(args) + tuple(kwargs.values()))
|
285
|
+
return hash(to_hash)
|
286
|
+
except TypeError:
|
287
|
+
raise Honk(f"Unhashable argument to task {self.name}: {args} {kwargs}")
|
288
|
+
|
289
|
+
def __get_current_flow_state(self) -> FlowState:
|
290
|
+
state = _current_flow_state.get()
|
291
|
+
if state is None:
|
292
|
+
raise Honk("No current flow state")
|
293
|
+
return state
|
294
|
+
|
295
|
+
|
296
|
+
@overload
|
297
|
+
def task[**P, R: Result](generator: Callable[P, Awaitable[R]], /) -> Task[P, R]: ...
|
298
|
+
@overload
|
299
|
+
def task[**P, R: Result](
|
300
|
+
*, retries: int = 0
|
301
|
+
) -> Callable[[Callable[P, Awaitable[R]]], Task[P, R]]: ...
|
302
|
+
def task[**P, R: Result](
|
303
|
+
generator: Callable[P, Awaitable[R]] | None = None,
|
304
|
+
/,
|
305
|
+
*,
|
306
|
+
retries: int = 0,
|
307
|
+
) -> Task[P, R] | Callable[[Callable[P, Awaitable[R]]], Task[P, R]]:
|
308
|
+
if generator is None:
|
309
|
+
|
310
|
+
def decorator(fn: Callable[P, Awaitable[R]]) -> Task[P, R]:
|
311
|
+
return Task(fn, retries=retries)
|
312
|
+
|
313
|
+
return decorator
|
314
|
+
|
315
|
+
return Task(generator, retries=retries)
|
316
|
+
|
317
|
+
|
318
|
+
@overload
|
319
|
+
def flow[**P](fn: Callable[P, Awaitable[None]], /) -> Flow[P]: ...
|
320
|
+
@overload
|
321
|
+
def flow[**P](
|
322
|
+
*, name: str | None = None
|
323
|
+
) -> Callable[[Callable[P, Awaitable[None]]], Flow[P]]: ...
|
324
|
+
def flow[**P](
|
325
|
+
fn: Callable[P, Awaitable[None]] | None = None, /, *, name: str | None = None
|
326
|
+
) -> Flow[P] | Callable[[Callable[P, Awaitable[None]]], Flow[P]]:
|
327
|
+
if fn is None:
|
328
|
+
|
329
|
+
def decorator(fn: Callable[P, Awaitable[None]]) -> Flow[P]:
|
330
|
+
return Flow(fn, name=name)
|
331
|
+
|
332
|
+
return decorator
|
333
|
+
|
334
|
+
return Flow(fn, name=name)
|
goose_py-0.1.4/goose/__init__.py
DELETED
File without changes
|
goose_py-0.1.4/goose/agent.py
DELETED
@@ -1,75 +0,0 @@
|
|
1
|
-
import logging
|
2
|
-
from datetime import datetime
|
3
|
-
from typing import Any, Callable
|
4
|
-
|
5
|
-
from litellm import acompletion
|
6
|
-
from pydantic import BaseModel
|
7
|
-
|
8
|
-
from goose.types import (
|
9
|
-
AgentResponse,
|
10
|
-
AssistantMessage,
|
11
|
-
GeminiModel,
|
12
|
-
SystemMessage,
|
13
|
-
UserMessage,
|
14
|
-
)
|
15
|
-
|
16
|
-
|
17
|
-
class Agent:
|
18
|
-
def __init__(
|
19
|
-
self,
|
20
|
-
*,
|
21
|
-
flow_name: str,
|
22
|
-
run_id: str,
|
23
|
-
logger: Callable[[AgentResponse[Any]], None] | None = None,
|
24
|
-
) -> None:
|
25
|
-
self.flow_name = flow_name
|
26
|
-
self.run_id = run_id
|
27
|
-
self.logger = logger or logging.info
|
28
|
-
|
29
|
-
async def __call__[R: BaseModel](
|
30
|
-
self,
|
31
|
-
*,
|
32
|
-
messages: list[UserMessage | AssistantMessage],
|
33
|
-
model: GeminiModel,
|
34
|
-
response_model: type[R],
|
35
|
-
task_name: str,
|
36
|
-
system: SystemMessage | None = None,
|
37
|
-
) -> R:
|
38
|
-
start_time = datetime.now()
|
39
|
-
rendered_messages = [message.render() for message in messages]
|
40
|
-
if system is not None:
|
41
|
-
rendered_messages.insert(0, system.render())
|
42
|
-
|
43
|
-
response = await acompletion(
|
44
|
-
model=model.value,
|
45
|
-
messages=rendered_messages,
|
46
|
-
response_format={
|
47
|
-
"type": "json_object",
|
48
|
-
"response_schema": response_model.model_json_schema(),
|
49
|
-
"enforce_validation": True,
|
50
|
-
},
|
51
|
-
)
|
52
|
-
|
53
|
-
if len(response.choices) == 0:
|
54
|
-
raise RuntimeError("No content returned from LLM call.")
|
55
|
-
|
56
|
-
parsed_response = response_model.model_validate_json(
|
57
|
-
response.choices[0].message.content
|
58
|
-
)
|
59
|
-
end_time = datetime.now()
|
60
|
-
agent_response = AgentResponse(
|
61
|
-
response=parsed_response,
|
62
|
-
id=self.run_id,
|
63
|
-
flow_name=self.flow_name,
|
64
|
-
task_name=task_name,
|
65
|
-
model=model,
|
66
|
-
system=system,
|
67
|
-
input_messages=messages,
|
68
|
-
input_tokens=response.usage.prompt_tokens,
|
69
|
-
output_tokens=response.usage.completion_tokens,
|
70
|
-
start_time=start_time,
|
71
|
-
end_time=end_time,
|
72
|
-
)
|
73
|
-
|
74
|
-
self.logger(agent_response)
|
75
|
-
return agent_response.response
|
@@ -1,45 +0,0 @@
|
|
1
|
-
from pydantic import BaseModel
|
2
|
-
|
3
|
-
from goose.types import UserMessage
|
4
|
-
|
5
|
-
|
6
|
-
class ConversationState[R: BaseModel](BaseModel):
|
7
|
-
user_messages: list[UserMessage]
|
8
|
-
results: list[R]
|
9
|
-
|
10
|
-
|
11
|
-
class Conversation[R: BaseModel]:
|
12
|
-
def __init__(
|
13
|
-
self,
|
14
|
-
*,
|
15
|
-
user_messages: list[UserMessage] | None = None,
|
16
|
-
results: list[R] | None = None,
|
17
|
-
) -> None:
|
18
|
-
self.user_messages = user_messages or []
|
19
|
-
self.results = results or []
|
20
|
-
|
21
|
-
@classmethod
|
22
|
-
def load(cls, *, state: ConversationState[R]) -> "Conversation[R]":
|
23
|
-
return cls(user_messages=state.user_messages, results=state.results)
|
24
|
-
|
25
|
-
@property
|
26
|
-
def current_result(self) -> R:
|
27
|
-
if len(self.results) == 0:
|
28
|
-
raise RuntimeError("No results in conversation")
|
29
|
-
|
30
|
-
return self.results[-1]
|
31
|
-
|
32
|
-
def add_message(self, *, message: UserMessage) -> None:
|
33
|
-
self.user_messages.append(message)
|
34
|
-
|
35
|
-
def add_result(self, *, result: R) -> None:
|
36
|
-
self.results.append(result)
|
37
|
-
|
38
|
-
def replace_last_result(self, *, result: R) -> None:
|
39
|
-
if len(self.results) == 0:
|
40
|
-
self.results.append(result)
|
41
|
-
else:
|
42
|
-
self.results[-1] = result
|
43
|
-
|
44
|
-
def dump(self) -> ConversationState[R]:
|
45
|
-
return ConversationState(user_messages=self.user_messages, results=self.results)
|
goose_py-0.1.4/goose/core.py
DELETED
@@ -1,294 +0,0 @@
|
|
1
|
-
import asyncio
|
2
|
-
import contextvars
|
3
|
-
import inspect
|
4
|
-
from collections import defaultdict
|
5
|
-
from types import TracebackType
|
6
|
-
from typing import Any, Awaitable, Callable, Protocol, Self, overload
|
7
|
-
|
8
|
-
from graphlib import TopologicalSorter
|
9
|
-
from pydantic import BaseModel
|
10
|
-
|
11
|
-
from goose.agent import Agent
|
12
|
-
from goose.errors import Honk
|
13
|
-
from goose.conversation import Conversation, ConversationState
|
14
|
-
from goose.types import AgentResponse, UserMessage
|
15
|
-
|
16
|
-
|
17
|
-
class NodeState[R: BaseModel](BaseModel):
|
18
|
-
name: str
|
19
|
-
conversation: ConversationState[R]
|
20
|
-
|
21
|
-
@property
|
22
|
-
def result(self) -> R:
|
23
|
-
return self.conversation.results[-1]
|
24
|
-
|
25
|
-
|
26
|
-
class FlowState(BaseModel):
    """Serializable snapshot of a flow: the state of each of its nodes."""

    # One entry per node, in the order the nodes were added to the flow.
    nodes: list[NodeState[BaseModel]]
|
28
|
-
|
29
|
-
|
30
|
-
class NoResult:
    """Sentinel type marking a node whose result has not been produced yet."""
|
32
|
-
|
33
|
-
|
34
|
-
class IRegenerator[R: BaseModel](Protocol):
|
35
|
-
async def __call__(self, *, result: R, conversation: Conversation[R]) -> R: ...
|
36
|
-
|
37
|
-
|
38
|
-
class Task[**P, R: BaseModel]:
|
39
|
-
def __init__(
|
40
|
-
self, generator: Callable[P, Awaitable[R]], /, *, retries: int = 0
|
41
|
-
) -> None:
|
42
|
-
self.retries = retries
|
43
|
-
self._generator = generator
|
44
|
-
self._regenerator: IRegenerator[R] | None = None
|
45
|
-
self._signature = inspect.signature(generator)
|
46
|
-
self.__validate_fn()
|
47
|
-
|
48
|
-
@property
|
49
|
-
def result_type(self) -> type[R]:
|
50
|
-
return_type = self._generator.__annotations__.get("return")
|
51
|
-
if return_type is None:
|
52
|
-
raise Honk("Task must have a return type annotation")
|
53
|
-
|
54
|
-
return return_type
|
55
|
-
|
56
|
-
@property
|
57
|
-
def name(self) -> str:
|
58
|
-
return self._generator.__name__
|
59
|
-
|
60
|
-
def regenerator(self, regenerator: IRegenerator[R], /) -> Self:
|
61
|
-
self._regenerator = regenerator
|
62
|
-
return self
|
63
|
-
|
64
|
-
async def generate(self, *args: P.args, **kwargs: P.kwargs) -> R:
|
65
|
-
return await self._generator(*args, **kwargs)
|
66
|
-
|
67
|
-
async def regenerate(self, *, result: R, conversation: Conversation[R]) -> R:
|
68
|
-
if self._regenerator is None:
|
69
|
-
raise Honk("Task does not have a regenerator implemented")
|
70
|
-
|
71
|
-
return await self._regenerator(result=result, conversation=conversation)
|
72
|
-
|
73
|
-
def __call__(self, *args: P.args, **kwargs: P.kwargs) -> "Node[R]":
|
74
|
-
arguments = self._signature.bind(*args, **kwargs).arguments
|
75
|
-
return Node(task=self, arguments=arguments, result_type=self.result_type)
|
76
|
-
|
77
|
-
def __validate_fn(self) -> None:
|
78
|
-
if any(
|
79
|
-
param.kind == inspect.Parameter.POSITIONAL_ONLY
|
80
|
-
for param in self._signature.parameters.values()
|
81
|
-
):
|
82
|
-
raise Honk("Positional-only parameters are not supported in Tasks")
|
83
|
-
|
84
|
-
|
85
|
-
class Node[R: BaseModel]:
|
86
|
-
def __init__(
|
87
|
-
self,
|
88
|
-
*,
|
89
|
-
task: Task[Any, R],
|
90
|
-
arguments: dict[str, Any],
|
91
|
-
result_type: type[R],
|
92
|
-
conversation: Conversation[R] | None = None,
|
93
|
-
) -> None:
|
94
|
-
self.task = task
|
95
|
-
self.arguments = arguments
|
96
|
-
self.result_type = result_type
|
97
|
-
self.conversation = conversation or Conversation[R]()
|
98
|
-
self.name = task.name
|
99
|
-
|
100
|
-
self._result: R | NoResult = NoResult()
|
101
|
-
current_flow = Flow.get_current()
|
102
|
-
if current_flow is None:
|
103
|
-
raise Honk("Cannot create a node without an active flow")
|
104
|
-
self.id = current_flow.add_node(node=self)
|
105
|
-
|
106
|
-
@property
|
107
|
-
def has_result(self) -> bool:
|
108
|
-
return not isinstance(self._result, NoResult)
|
109
|
-
|
110
|
-
@property
|
111
|
-
def result(self) -> R:
|
112
|
-
if isinstance(self._result, NoResult):
|
113
|
-
raise Honk("Cannot access result of a node before it has run")
|
114
|
-
return self._result
|
115
|
-
|
116
|
-
async def generate(self) -> None:
|
117
|
-
self._result = await self.task.generate(**self.arguments)
|
118
|
-
self.conversation.replace_last_result(result=self.result)
|
119
|
-
|
120
|
-
async def regenerate(self, *, message: UserMessage) -> None:
|
121
|
-
self.conversation.add_message(message=message)
|
122
|
-
self._result = await self.task.regenerate(
|
123
|
-
result=self.result, conversation=self.conversation
|
124
|
-
)
|
125
|
-
self.conversation.add_result(result=self.result)
|
126
|
-
|
127
|
-
def dump_state(self) -> NodeState[R]:
|
128
|
-
return NodeState(name=self.name, conversation=self.conversation.dump())
|
129
|
-
|
130
|
-
def load_state(self, *, state: NodeState[R]) -> None:
|
131
|
-
self._result = state.result
|
132
|
-
self.conversation = Conversation[self.result_type].load(
|
133
|
-
state=state.conversation
|
134
|
-
)
|
135
|
-
|
136
|
-
def get_inbound_nodes(self) -> list["Node[BaseModel]"]:
|
137
|
-
def __find_nodes(
|
138
|
-
obj: Any, visited: set[int] | None = None
|
139
|
-
) -> list["Node[BaseModel]"]:
|
140
|
-
if visited is None:
|
141
|
-
visited = set()
|
142
|
-
|
143
|
-
if isinstance(obj, Node):
|
144
|
-
return [obj]
|
145
|
-
elif isinstance(obj, dict):
|
146
|
-
return [
|
147
|
-
node
|
148
|
-
for value in obj.values()
|
149
|
-
for node in __find_nodes(value, visited)
|
150
|
-
]
|
151
|
-
elif isinstance(obj, list):
|
152
|
-
return [node for item in obj for node in __find_nodes(item, visited)]
|
153
|
-
elif isinstance(obj, tuple):
|
154
|
-
return [node for item in obj for node in __find_nodes(item, visited)]
|
155
|
-
elif isinstance(obj, set):
|
156
|
-
return [node for item in obj for node in __find_nodes(item, visited)]
|
157
|
-
elif hasattr(obj, "__dict__"):
|
158
|
-
return [
|
159
|
-
node
|
160
|
-
for value in obj.__dict__.values()
|
161
|
-
for node in __find_nodes(value, visited)
|
162
|
-
]
|
163
|
-
return []
|
164
|
-
|
165
|
-
return __find_nodes(self.arguments)
|
166
|
-
|
167
|
-
def __hash__(self) -> int:
|
168
|
-
return hash(self.id)
|
169
|
-
|
170
|
-
|
171
|
-
class Flow:
    """Collects the Nodes created inside its ``with`` block and runs them in
    dependency order."""

    _current: contextvars.ContextVar["Flow | None"] = contextvars.ContextVar(
        "current_flow", default=None
    )

    def __init__(
        self,
        *,
        name: str,
        run_id: str,
        agent_logger: Callable[[AgentResponse[Any]], None] | None = None,
    ) -> None:
        self.name = name
        self._nodes: list[Node[BaseModel]] = []
        self._agent = Agent(flow_name=self.name, run_id=run_id, logger=agent_logger)

    @property
    def agent(self) -> Agent:
        """The agent bound to this flow's name and run id."""
        return self._agent

    def dump_state(self) -> FlowState:
        """Snapshot every node in this flow."""
        return FlowState(nodes=[node.dump_state() for node in self._nodes])

    def load_state(self, *, flow_state: FlowState) -> None:
        """Restore node states, matching nodes by name.

        Raises:
            Honk: If the snapshot names a node this flow does not have.
        """
        # NOTE(review): nodes sharing a name collapse to the last one here.
        nodes_by_name = {node.name: node for node in self._nodes}
        for node_state in flow_state.nodes:
            matching_node = nodes_by_name.get(node_state.name)
            if matching_node is None:
                raise Honk(f"Node {node_state.name} from state not found in flow")

            matching_node.load_state(state=node_state)

    async def generate(self) -> None:
        """Generate every node, each only after all nodes it depends on."""
        graph = {node: node.get_inbound_nodes() for node in self._nodes}
        sorter = TopologicalSorter(graph)
        sorter.prepare()

        while sorter.is_active():
            ready_nodes = list(sorter.get_ready())
            # Let the whole batch finish before marking it done. Marking nodes
            # done at scheduling time let dependents start before their inputs
            # had results.
            async with asyncio.TaskGroup() as task_group:
                for node in ready_nodes:
                    task_group.create_task(node.generate())
            sorter.done(*ready_nodes)

    async def regenerate(self, *, target: Node[Any], message: UserMessage) -> None:
        """Regenerate ``target`` from ``message``, then every node downstream.

        Raises:
            Honk: If ``target`` has no result yet.
        """
        if not target.has_result:
            raise Honk("Cannot regenerate a node without a result")

        await target.regenerate(message=message)

        # Invert the dependency graph: node -> nodes that consume its result.
        dependents: dict[Node[BaseModel], set[Node[BaseModel]]] = defaultdict(set)
        for node in self._nodes:
            for inbound_node in node.get_inbound_nodes():
                dependents[inbound_node].add(node)

        # Breadth-first walk of everything downstream of target, recording the
        # edges so the subgraph can be ordered topologically.
        subgraph: dict[Node[BaseModel], set[Node[BaseModel]]] = defaultdict(set)
        queue: list[Node[BaseModel]] = [target]
        while queue:
            node = queue.pop(0)
            for dependent in dependents[node]:
                # Test membership BEFORE touching subgraph: the defaultdict
                # inserts the key on access, so a check afterwards was always
                # False and nothing past the first level was ever visited.
                first_visit = dependent not in subgraph
                subgraph[dependent].add(node)
                if first_visit:
                    queue.append(dependent)

        if not subgraph:
            return

        sorter = TopologicalSorter(subgraph)
        sorter.prepare()
        while sorter.is_active():
            ready_nodes = list(sorter.get_ready())
            async with asyncio.TaskGroup() as task_group:
                for node in ready_nodes:
                    # target was already regenerated above.
                    if node != target:
                        task_group.create_task(node.generate())
            sorter.done(*ready_nodes)

    @classmethod
    def get_current(cls) -> "Flow | None":
        """The flow entered in the current context, if any."""
        return cls._current.get()

    def add_node(self, *, node: Node[Any]) -> str:
        """Register ``node`` and return an id of the form ``<name>_<n>``."""
        number = sum(1 for existing in self._nodes if existing.name == node.name)
        self._nodes.append(node)
        return f"{node.name}_{number}"

    def __enter__(self) -> Self:
        if self._current.get() is not None:
            raise Honk("Cannot enter a new flow while another flow is already active")
        self._current.set(self)
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        traceback: TracebackType | None,
    ) -> None:
        self._current.set(None)
|
281
|
-
|
282
|
-
|
283
|
-
@overload
|
284
|
-
def task[**P, R: BaseModel](fn: Callable[P, Awaitable[R]], /) -> Task[P, R]: ...
|
285
|
-
@overload
|
286
|
-
def task[**P, R: BaseModel](
|
287
|
-
*, retries: int = 0
|
288
|
-
) -> Callable[[Callable[P, Awaitable[R]]], Task[P, R]]: ...
|
289
|
-
def task[**P, R: BaseModel](
|
290
|
-
fn: Callable[P, Awaitable[R]] | None = None, /, *, retries: int = 0
|
291
|
-
) -> Task[P, R] | Callable[[Callable[P, Awaitable[R]]], Task[P, R]]:
|
292
|
-
if fn is None:
|
293
|
-
return lambda fn: Task(fn, retries=retries)
|
294
|
-
return Task(fn, retries=retries)
|
File without changes
|
File without changes
|