goose-py 0.5.0__tar.gz → 0.6.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {goose_py-0.5.0 → goose_py-0.6.0}/PKG-INFO +1 -1
- goose_py-0.6.0/goose/__init__.py +5 -0
- goose_py-0.5.0/goose/agent.py → goose_py-0.6.0/goose/_agent.py +3 -91
- goose_py-0.6.0/goose/_conversation.py +36 -0
- goose_py-0.6.0/goose/_flow.py +106 -0
- goose_py-0.6.0/goose/_state.py +190 -0
- goose_py-0.5.0/goose/store.py → goose_py-0.6.0/goose/_store.py +8 -6
- goose_py-0.6.0/goose/_task.py +136 -0
- goose_py-0.6.0/goose/types/agent.py +92 -0
- {goose_py-0.5.0 → goose_py-0.6.0}/pyproject.toml +5 -1
- goose_py-0.5.0/goose/flow.py +0 -458
- {goose_py-0.5.0 → goose_py-0.6.0}/README.md +0 -0
- /goose_py-0.5.0/goose/result.py → /goose_py-0.6.0/goose/_result.py +0 -0
- {goose_py-0.5.0 → goose_py-0.6.0}/goose/errors.py +0 -0
- {goose_py-0.5.0 → goose_py-0.6.0}/goose/py.typed +0 -0
- {goose_py-0.5.0/goose → goose_py-0.6.0/goose/types}/__init__.py +0 -0
@@ -1,101 +1,13 @@
|
|
1
|
-
import base64
|
2
1
|
import json
|
3
2
|
import logging
|
4
3
|
from datetime import datetime
|
5
|
-
from
|
6
|
-
from typing import Any, ClassVar, Literal, NotRequired, Protocol, TypedDict
|
4
|
+
from typing import Any, ClassVar, Protocol, TypedDict
|
7
5
|
|
8
6
|
from litellm import acompletion
|
9
7
|
from pydantic import BaseModel, computed_field
|
10
|
-
from goose.result import Result, TextResult
|
11
8
|
|
12
|
-
|
13
|
-
|
14
|
-
PRO = "vertex_ai/gemini-1.5-pro"
|
15
|
-
FLASH = "vertex_ai/gemini-1.5-flash"
|
16
|
-
FLASH_8B = "vertex_ai/gemini-1.5-flash-8b"
|
17
|
-
|
18
|
-
|
19
|
-
class UserMediaContentType(StrEnum):
|
20
|
-
# images
|
21
|
-
JPEG = "image/jpeg"
|
22
|
-
PNG = "image/png"
|
23
|
-
WEBP = "image/webp"
|
24
|
-
|
25
|
-
# audio
|
26
|
-
MP3 = "audio/mp3"
|
27
|
-
WAV = "audio/wav"
|
28
|
-
|
29
|
-
# files
|
30
|
-
PDF = "application/pdf"
|
31
|
-
|
32
|
-
|
33
|
-
class LLMTextMessagePart(TypedDict):
|
34
|
-
type: Literal["text"]
|
35
|
-
text: str
|
36
|
-
|
37
|
-
|
38
|
-
class LLMMediaMessagePart(TypedDict):
|
39
|
-
type: Literal["image_url"]
|
40
|
-
image_url: str
|
41
|
-
|
42
|
-
|
43
|
-
class CacheControl(TypedDict):
|
44
|
-
type: Literal["ephemeral"]
|
45
|
-
|
46
|
-
|
47
|
-
class LLMMessage(TypedDict):
|
48
|
-
role: Literal["user", "assistant", "system"]
|
49
|
-
content: list[LLMTextMessagePart | LLMMediaMessagePart]
|
50
|
-
cache_control: NotRequired[CacheControl]
|
51
|
-
|
52
|
-
|
53
|
-
class TextMessagePart(BaseModel):
|
54
|
-
text: str
|
55
|
-
|
56
|
-
def render(self) -> LLMTextMessagePart:
|
57
|
-
return {"type": "text", "text": self.text}
|
58
|
-
|
59
|
-
|
60
|
-
class MediaMessagePart(BaseModel):
|
61
|
-
content_type: UserMediaContentType
|
62
|
-
content: bytes
|
63
|
-
|
64
|
-
def render(self) -> LLMMediaMessagePart:
|
65
|
-
return {
|
66
|
-
"type": "image_url",
|
67
|
-
"image_url": f"data:{self.content_type};base64,{base64.b64encode(self.content).decode()}",
|
68
|
-
}
|
69
|
-
|
70
|
-
|
71
|
-
class UserMessage(BaseModel):
|
72
|
-
parts: list[TextMessagePart | MediaMessagePart]
|
73
|
-
|
74
|
-
def render(self) -> LLMMessage:
|
75
|
-
content: LLMMessage = {
|
76
|
-
"role": "user",
|
77
|
-
"content": [part.render() for part in self.parts],
|
78
|
-
}
|
79
|
-
if any(isinstance(part, MediaMessagePart) for part in self.parts):
|
80
|
-
content["cache_control"] = {"type": "ephemeral"}
|
81
|
-
return content
|
82
|
-
|
83
|
-
|
84
|
-
class AssistantMessage(BaseModel):
|
85
|
-
text: str
|
86
|
-
|
87
|
-
def render(self) -> LLMMessage:
|
88
|
-
return {"role": "assistant", "content": [{"type": "text", "text": self.text}]}
|
89
|
-
|
90
|
-
|
91
|
-
class SystemMessage(BaseModel):
|
92
|
-
parts: list[TextMessagePart | MediaMessagePart]
|
93
|
-
|
94
|
-
def render(self) -> LLMMessage:
|
95
|
-
return {
|
96
|
-
"role": "system",
|
97
|
-
"content": [part.render() for part in self.parts],
|
98
|
-
}
|
9
|
+
from goose._result import Result, TextResult
|
10
|
+
from goose.types.agent import AssistantMessage, GeminiModel, SystemMessage, UserMessage
|
99
11
|
|
100
12
|
|
101
13
|
class AgentResponseDump(TypedDict):
|
@@ -0,0 +1,36 @@
|
|
1
|
+
from pydantic import BaseModel
|
2
|
+
|
3
|
+
from goose._result import Result
|
4
|
+
from goose.types.agent import AssistantMessage, LLMMessage, SystemMessage, UserMessage
|
5
|
+
|
6
|
+
|
7
|
+
class Conversation[R: Result](BaseModel):
|
8
|
+
user_messages: list[UserMessage]
|
9
|
+
result_messages: list[R]
|
10
|
+
context: SystemMessage | None = None
|
11
|
+
|
12
|
+
@property
|
13
|
+
def awaiting_response(self) -> bool:
|
14
|
+
return len(self.user_messages) == len(self.result_messages)
|
15
|
+
|
16
|
+
def render(self) -> list[LLMMessage]:
|
17
|
+
messages: list[LLMMessage] = []
|
18
|
+
if self.context is not None:
|
19
|
+
messages.append(self.context.render())
|
20
|
+
|
21
|
+
for message_index in range(len(self.user_messages)):
|
22
|
+
messages.append(
|
23
|
+
AssistantMessage(
|
24
|
+
text=self.result_messages[message_index].model_dump_json()
|
25
|
+
).render()
|
26
|
+
)
|
27
|
+
messages.append(self.user_messages[message_index].render())
|
28
|
+
|
29
|
+
if len(self.result_messages) > len(self.user_messages):
|
30
|
+
messages.append(
|
31
|
+
AssistantMessage(
|
32
|
+
text=self.result_messages[-1].model_dump_json()
|
33
|
+
).render()
|
34
|
+
)
|
35
|
+
|
36
|
+
return messages
|
@@ -0,0 +1,106 @@
|
|
1
|
+
from contextlib import asynccontextmanager
|
2
|
+
from types import CodeType
|
3
|
+
from typing import AsyncIterator, Awaitable, Callable, Protocol, overload
|
4
|
+
|
5
|
+
from goose._agent import Agent, IAgentLogger
|
6
|
+
from goose._conversation import Conversation
|
7
|
+
from goose._result import Result
|
8
|
+
from goose._state import FlowRun, get_current_flow_run, set_current_flow_run
|
9
|
+
from goose._store import IFlowRunStore, InMemoryFlowRunStore
|
10
|
+
from goose.errors import Honk
|
11
|
+
|
12
|
+
|
13
|
+
class IAdapter[ResultT: Result](Protocol):
|
14
|
+
__code__: CodeType
|
15
|
+
|
16
|
+
async def __call__(
|
17
|
+
self, *, conversation: Conversation[ResultT], agent: Agent
|
18
|
+
) -> ResultT: ...
|
19
|
+
|
20
|
+
|
21
|
+
class Flow[**P]:
|
22
|
+
def __init__(
|
23
|
+
self,
|
24
|
+
fn: Callable[P, Awaitable[None]],
|
25
|
+
/,
|
26
|
+
*,
|
27
|
+
name: str | None = None,
|
28
|
+
store: IFlowRunStore | None = None,
|
29
|
+
agent_logger: IAgentLogger | None = None,
|
30
|
+
) -> None:
|
31
|
+
self._fn = fn
|
32
|
+
self._name = name
|
33
|
+
self._agent_logger = agent_logger
|
34
|
+
self._store = store or InMemoryFlowRunStore(flow_name=self.name)
|
35
|
+
|
36
|
+
@property
|
37
|
+
def name(self) -> str:
|
38
|
+
return self._name or self._fn.__name__
|
39
|
+
|
40
|
+
@property
|
41
|
+
def current_run(self) -> FlowRun:
|
42
|
+
run = get_current_flow_run()
|
43
|
+
if run is None:
|
44
|
+
raise Honk("No current flow run")
|
45
|
+
return run
|
46
|
+
|
47
|
+
@asynccontextmanager
|
48
|
+
async def start_run(self, *, run_id: str) -> AsyncIterator[FlowRun]:
|
49
|
+
existing_run = await self._store.get(run_id=run_id)
|
50
|
+
if existing_run is None:
|
51
|
+
run = FlowRun()
|
52
|
+
else:
|
53
|
+
run = existing_run
|
54
|
+
|
55
|
+
old_run = get_current_flow_run()
|
56
|
+
set_current_flow_run(run)
|
57
|
+
|
58
|
+
run.start(flow_name=self.name, run_id=run_id, agent_logger=self._agent_logger)
|
59
|
+
yield run
|
60
|
+
await self._store.save(run=run)
|
61
|
+
run.end()
|
62
|
+
|
63
|
+
set_current_flow_run(old_run)
|
64
|
+
|
65
|
+
async def generate(self, *args: P.args, **kwargs: P.kwargs) -> None:
|
66
|
+
flow_run = get_current_flow_run()
|
67
|
+
if flow_run is None:
|
68
|
+
raise Honk("No current flow run")
|
69
|
+
|
70
|
+
flow_run.set_flow_inputs(*args, **kwargs)
|
71
|
+
await self._fn(*args, **kwargs)
|
72
|
+
|
73
|
+
async def regenerate(self) -> None:
|
74
|
+
flow_run = get_current_flow_run()
|
75
|
+
if flow_run is None:
|
76
|
+
raise Honk("No current flow run")
|
77
|
+
|
78
|
+
flow_args, flow_kwargs = flow_run.flow_inputs
|
79
|
+
await self._fn(*flow_args, **flow_kwargs)
|
80
|
+
|
81
|
+
|
82
|
+
@overload
|
83
|
+
def flow[**P](fn: Callable[P, Awaitable[None]], /) -> Flow[P]: ...
|
84
|
+
@overload
|
85
|
+
def flow[**P](
|
86
|
+
*,
|
87
|
+
name: str | None = None,
|
88
|
+
store: IFlowRunStore | None = None,
|
89
|
+
agent_logger: IAgentLogger | None = None,
|
90
|
+
) -> Callable[[Callable[P, Awaitable[None]]], Flow[P]]: ...
|
91
|
+
def flow[**P](
|
92
|
+
fn: Callable[P, Awaitable[None]] | None = None,
|
93
|
+
/,
|
94
|
+
*,
|
95
|
+
name: str | None = None,
|
96
|
+
store: IFlowRunStore | None = None,
|
97
|
+
agent_logger: IAgentLogger | None = None,
|
98
|
+
) -> Flow[P] | Callable[[Callable[P, Awaitable[None]]], Flow[P]]:
|
99
|
+
if fn is None:
|
100
|
+
|
101
|
+
def decorator(fn: Callable[P, Awaitable[None]]) -> Flow[P]:
|
102
|
+
return Flow(fn, name=name, store=store, agent_logger=agent_logger)
|
103
|
+
|
104
|
+
return decorator
|
105
|
+
|
106
|
+
return Flow(fn, name=name, store=store, agent_logger=agent_logger)
|
@@ -0,0 +1,190 @@
|
|
1
|
+
from contextvars import ContextVar
|
2
|
+
from dataclasses import dataclass
|
3
|
+
from typing import TYPE_CHECKING, Any, Self
|
4
|
+
|
5
|
+
from pydantic import BaseModel
|
6
|
+
|
7
|
+
from goose._agent import (
|
8
|
+
Agent,
|
9
|
+
IAgentLogger,
|
10
|
+
SystemMessage,
|
11
|
+
UserMessage,
|
12
|
+
)
|
13
|
+
from goose._conversation import Conversation
|
14
|
+
from goose._result import Result
|
15
|
+
from goose.errors import Honk
|
16
|
+
|
17
|
+
if TYPE_CHECKING:
|
18
|
+
from goose._task import Task
|
19
|
+
|
20
|
+
|
21
|
+
@dataclass
|
22
|
+
class FlowRunState:
|
23
|
+
node_states: dict[tuple[str, int], str]
|
24
|
+
flow_args: tuple[Any, ...]
|
25
|
+
flow_kwargs: dict[str, Any]
|
26
|
+
|
27
|
+
|
28
|
+
class NodeState[ResultT: Result](BaseModel):
|
29
|
+
task_name: str
|
30
|
+
index: int
|
31
|
+
conversation: Conversation[ResultT]
|
32
|
+
last_hash: int
|
33
|
+
|
34
|
+
@property
|
35
|
+
def result(self) -> ResultT:
|
36
|
+
if len(self.conversation.result_messages) == 0:
|
37
|
+
raise Honk("Node awaiting response, has no result")
|
38
|
+
|
39
|
+
return self.conversation.result_messages[-1]
|
40
|
+
|
41
|
+
def set_context(self, *, context: SystemMessage) -> Self:
|
42
|
+
self.conversation.context = context
|
43
|
+
return self
|
44
|
+
|
45
|
+
def add_result(
|
46
|
+
self,
|
47
|
+
*,
|
48
|
+
result: ResultT,
|
49
|
+
new_hash: int | None = None,
|
50
|
+
overwrite: bool = False,
|
51
|
+
) -> Self:
|
52
|
+
if overwrite and len(self.conversation.result_messages) > 0:
|
53
|
+
self.conversation.result_messages[-1] = result
|
54
|
+
else:
|
55
|
+
self.conversation.result_messages.append(result)
|
56
|
+
if new_hash is not None:
|
57
|
+
self.last_hash = new_hash
|
58
|
+
return self
|
59
|
+
|
60
|
+
def add_user_message(self, *, message: UserMessage) -> Self:
|
61
|
+
self.conversation.user_messages.append(message)
|
62
|
+
return self
|
63
|
+
|
64
|
+
|
65
|
+
class FlowRun:
|
66
|
+
def __init__(self) -> None:
|
67
|
+
self._node_states: dict[tuple[str, int], str] = {}
|
68
|
+
self._last_requested_indices: dict[str, int] = {}
|
69
|
+
self._flow_name = ""
|
70
|
+
self._id = ""
|
71
|
+
self._agent: Agent | None = None
|
72
|
+
self._flow_args: tuple[Any, ...] | None = None
|
73
|
+
self._flow_kwargs: dict[str, Any] | None = None
|
74
|
+
|
75
|
+
@property
|
76
|
+
def flow_name(self) -> str:
|
77
|
+
return self._flow_name
|
78
|
+
|
79
|
+
@property
|
80
|
+
def id(self) -> str:
|
81
|
+
return self._id
|
82
|
+
|
83
|
+
@property
|
84
|
+
def agent(self) -> Agent:
|
85
|
+
if self._agent is None:
|
86
|
+
raise Honk("Agent is only accessible once a run is started")
|
87
|
+
return self._agent
|
88
|
+
|
89
|
+
@property
|
90
|
+
def flow_inputs(self) -> tuple[tuple[Any, ...], dict[str, Any]]:
|
91
|
+
if self._flow_args is None or self._flow_kwargs is None:
|
92
|
+
raise Honk("This Flow run has not been executed before")
|
93
|
+
|
94
|
+
return self._flow_args, self._flow_kwargs
|
95
|
+
|
96
|
+
def get_all[R: Result](self, *, task: "Task[Any, R]") -> list[NodeState[R]]:
|
97
|
+
matching_nodes: list[NodeState[R]] = []
|
98
|
+
for key, node_state in self._node_states.items():
|
99
|
+
if key[0] == task.name:
|
100
|
+
matching_nodes.append(
|
101
|
+
NodeState[task.result_type].model_validate_json(node_state)
|
102
|
+
)
|
103
|
+
return sorted(matching_nodes, key=lambda node: node.index)
|
104
|
+
|
105
|
+
def get[R: Result](self, *, task: "Task[Any, R]", index: int = 0) -> NodeState[R]:
|
106
|
+
if (
|
107
|
+
existing_node_state := self._node_states.get((task.name, index))
|
108
|
+
) is not None:
|
109
|
+
return NodeState[task.result_type].model_validate_json(existing_node_state)
|
110
|
+
else:
|
111
|
+
return NodeState[task.result_type](
|
112
|
+
task_name=task.name,
|
113
|
+
index=index,
|
114
|
+
conversation=Conversation[task.result_type](
|
115
|
+
user_messages=[], result_messages=[]
|
116
|
+
),
|
117
|
+
last_hash=0,
|
118
|
+
)
|
119
|
+
|
120
|
+
def set_flow_inputs(self, *args: Any, **kwargs: Any) -> None:
|
121
|
+
self._flow_args = args
|
122
|
+
self._flow_kwargs = kwargs
|
123
|
+
|
124
|
+
def add_node_state(self, node_state: NodeState[Any], /) -> None:
|
125
|
+
key = (node_state.task_name, node_state.index)
|
126
|
+
self._node_states[key] = node_state.model_dump_json()
|
127
|
+
|
128
|
+
def get_next[R: Result](self, *, task: "Task[Any, R]") -> NodeState[R]:
|
129
|
+
if task.name not in self._last_requested_indices:
|
130
|
+
self._last_requested_indices[task.name] = 0
|
131
|
+
else:
|
132
|
+
self._last_requested_indices[task.name] += 1
|
133
|
+
|
134
|
+
return self.get(task=task, index=self._last_requested_indices[task.name])
|
135
|
+
|
136
|
+
def start(
|
137
|
+
self,
|
138
|
+
*,
|
139
|
+
flow_name: str,
|
140
|
+
run_id: str,
|
141
|
+
agent_logger: IAgentLogger | None = None,
|
142
|
+
) -> None:
|
143
|
+
self._last_requested_indices = {}
|
144
|
+
self._flow_name = flow_name
|
145
|
+
self._id = run_id
|
146
|
+
self._agent = Agent(
|
147
|
+
flow_name=self.flow_name, run_id=self.id, logger=agent_logger
|
148
|
+
)
|
149
|
+
|
150
|
+
def end(self) -> None:
|
151
|
+
self._last_requested_indices = {}
|
152
|
+
self._flow_name = ""
|
153
|
+
self._id = ""
|
154
|
+
self._agent = None
|
155
|
+
|
156
|
+
def clear_node(self, *, task: "Task[Any, Result]", index: int) -> None:
|
157
|
+
key = (task.name, index)
|
158
|
+
if key in self._node_states:
|
159
|
+
del self._node_states[key]
|
160
|
+
|
161
|
+
def dump(self) -> FlowRunState:
|
162
|
+
flow_args, flow_kwargs = self.flow_inputs
|
163
|
+
|
164
|
+
return FlowRunState(
|
165
|
+
node_states=self._node_states,
|
166
|
+
flow_args=flow_args,
|
167
|
+
flow_kwargs=flow_kwargs,
|
168
|
+
)
|
169
|
+
|
170
|
+
@classmethod
|
171
|
+
def load(cls, flow_run_state: FlowRunState, /) -> Self:
|
172
|
+
flow_run = cls()
|
173
|
+
flow_run._node_states = flow_run_state.node_states
|
174
|
+
flow_run._flow_args = flow_run_state.flow_args
|
175
|
+
flow_run._flow_kwargs = flow_run_state.flow_kwargs
|
176
|
+
|
177
|
+
return flow_run
|
178
|
+
|
179
|
+
|
180
|
+
_current_flow_run: ContextVar[FlowRun | None] = ContextVar(
|
181
|
+
"current_flow_run", default=None
|
182
|
+
)
|
183
|
+
|
184
|
+
|
185
|
+
def get_current_flow_run() -> FlowRun | None:
|
186
|
+
return _current_flow_run.get()
|
187
|
+
|
188
|
+
|
189
|
+
def set_current_flow_run(flow_run: FlowRun | None) -> None:
|
190
|
+
_current_flow_run.set(flow_run)
|
@@ -1,9 +1,9 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
|
-
from typing import
|
3
|
+
from typing import Protocol
|
4
4
|
|
5
|
-
|
6
|
-
|
5
|
+
from goose._flow import FlowRun
|
6
|
+
from goose._state import FlowRunState
|
7
7
|
|
8
8
|
|
9
9
|
class IFlowRunStore(Protocol):
|
@@ -16,13 +16,15 @@ class IFlowRunStore(Protocol):
|
|
16
16
|
class InMemoryFlowRunStore(IFlowRunStore):
|
17
17
|
def __init__(self, *, flow_name: str) -> None:
|
18
18
|
self._flow_name = flow_name
|
19
|
-
self._runs: dict[str,
|
19
|
+
self._runs: dict[str, FlowRunState] = {}
|
20
20
|
|
21
21
|
async def get(self, *, run_id: str) -> FlowRun | None:
|
22
|
-
|
22
|
+
state = self._runs.get(run_id)
|
23
|
+
if state is not None:
|
24
|
+
return FlowRun.load(state)
|
23
25
|
|
24
26
|
async def save(self, *, run: FlowRun) -> None:
|
25
|
-
self._runs[run.id] = run
|
27
|
+
self._runs[run.id] = run.dump()
|
26
28
|
|
27
29
|
async def delete(self, *, run_id: str) -> None:
|
28
30
|
self._runs.pop(run_id, None)
|
@@ -0,0 +1,136 @@
|
|
1
|
+
from typing import Awaitable, Callable, overload
|
2
|
+
|
3
|
+
from goose._agent import Agent, GeminiModel, SystemMessage, UserMessage
|
4
|
+
from goose._conversation import Conversation
|
5
|
+
from goose._result import Result, TextResult
|
6
|
+
from goose._state import FlowRun, NodeState, get_current_flow_run
|
7
|
+
from goose.errors import Honk
|
8
|
+
from goose.types.agent import AssistantMessage
|
9
|
+
|
10
|
+
|
11
|
+
class Task[**P, R: Result]:
|
12
|
+
def __init__(
|
13
|
+
self,
|
14
|
+
generator: Callable[P, Awaitable[R]],
|
15
|
+
/,
|
16
|
+
*,
|
17
|
+
retries: int = 0,
|
18
|
+
adapter_model: GeminiModel = GeminiModel.FLASH,
|
19
|
+
) -> None:
|
20
|
+
self._generator = generator
|
21
|
+
self._retries = retries
|
22
|
+
self._adapter_model = adapter_model
|
23
|
+
self._adapter_model = adapter_model
|
24
|
+
|
25
|
+
@property
|
26
|
+
def result_type(self) -> type[R]:
|
27
|
+
result_type = self._generator.__annotations__.get("return")
|
28
|
+
if result_type is None:
|
29
|
+
raise Honk(f"Task {self.name} has no return type annotation")
|
30
|
+
return result_type
|
31
|
+
|
32
|
+
@property
|
33
|
+
def name(self) -> str:
|
34
|
+
return self._generator.__name__
|
35
|
+
|
36
|
+
async def generate(
|
37
|
+
self, state: NodeState[R], *args: P.args, **kwargs: P.kwargs
|
38
|
+
) -> R:
|
39
|
+
state_hash = self.__hash_task_call(*args, **kwargs)
|
40
|
+
if state_hash != state.last_hash:
|
41
|
+
result = await self._generator(*args, **kwargs)
|
42
|
+
state.add_result(result=result, new_hash=state_hash, overwrite=True)
|
43
|
+
return result
|
44
|
+
else:
|
45
|
+
return state.result
|
46
|
+
|
47
|
+
async def jam(
|
48
|
+
self,
|
49
|
+
*,
|
50
|
+
user_message: UserMessage,
|
51
|
+
context: SystemMessage | None = None,
|
52
|
+
index: int = 0,
|
53
|
+
) -> R:
|
54
|
+
flow_run = self.__get_current_flow_run()
|
55
|
+
node_state = flow_run.get(task=self, index=index)
|
56
|
+
|
57
|
+
if context is not None:
|
58
|
+
node_state.set_context(context=context)
|
59
|
+
node_state.add_user_message(message=user_message)
|
60
|
+
|
61
|
+
result = await self.__adapt(
|
62
|
+
conversation=node_state.conversation, agent=flow_run.agent
|
63
|
+
)
|
64
|
+
node_state.add_result(result=result)
|
65
|
+
flow_run.add_node_state(node_state)
|
66
|
+
|
67
|
+
return result
|
68
|
+
|
69
|
+
async def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R:
|
70
|
+
flow_run = self.__get_current_flow_run()
|
71
|
+
node_state = flow_run.get_next(task=self)
|
72
|
+
result = await self.generate(node_state, *args, **kwargs)
|
73
|
+
flow_run.add_node_state(node_state)
|
74
|
+
return result
|
75
|
+
|
76
|
+
async def __adapt(self, *, conversation: Conversation[R], agent: Agent) -> R:
|
77
|
+
messages: list[UserMessage | AssistantMessage] = []
|
78
|
+
for message_index in range(len(conversation.user_messages)):
|
79
|
+
user_message = conversation.user_messages[message_index]
|
80
|
+
result = conversation.result_messages[message_index]
|
81
|
+
|
82
|
+
if isinstance(result, TextResult):
|
83
|
+
assistant_text = result.text
|
84
|
+
else:
|
85
|
+
assistant_text = result.model_dump_json()
|
86
|
+
assistant_message = AssistantMessage(text=assistant_text)
|
87
|
+
messages.append(assistant_message)
|
88
|
+
messages.append(user_message)
|
89
|
+
|
90
|
+
return await agent(
|
91
|
+
messages=messages,
|
92
|
+
model=self._adapter_model,
|
93
|
+
task_name=f"adapt--{self.name}",
|
94
|
+
system=conversation.context,
|
95
|
+
response_model=self.result_type,
|
96
|
+
)
|
97
|
+
|
98
|
+
def __hash_task_call(self, *args: P.args, **kwargs: P.kwargs) -> int:
|
99
|
+
try:
|
100
|
+
to_hash = str(
|
101
|
+
tuple(args)
|
102
|
+
+ tuple(kwargs.values())
|
103
|
+
+ (self._generator.__code__, self._adapter_model)
|
104
|
+
)
|
105
|
+
return hash(to_hash)
|
106
|
+
except TypeError:
|
107
|
+
raise Honk(f"Unhashable argument to task {self.name}: {args} {kwargs}")
|
108
|
+
|
109
|
+
def __get_current_flow_run(self) -> FlowRun:
|
110
|
+
run = get_current_flow_run()
|
111
|
+
if run is None:
|
112
|
+
raise Honk("No current flow run")
|
113
|
+
return run
|
114
|
+
|
115
|
+
|
116
|
+
@overload
|
117
|
+
def task[**P, R: Result](generator: Callable[P, Awaitable[R]], /) -> Task[P, R]: ...
|
118
|
+
@overload
|
119
|
+
def task[**P, R: Result](
|
120
|
+
*, retries: int = 0, adapter_model: GeminiModel = GeminiModel.FLASH
|
121
|
+
) -> Callable[[Callable[P, Awaitable[R]]], Task[P, R]]: ...
|
122
|
+
def task[**P, R: Result](
|
123
|
+
generator: Callable[P, Awaitable[R]] | None = None,
|
124
|
+
/,
|
125
|
+
*,
|
126
|
+
retries: int = 0,
|
127
|
+
adapter_model: GeminiModel = GeminiModel.FLASH,
|
128
|
+
) -> Task[P, R] | Callable[[Callable[P, Awaitable[R]]], Task[P, R]]:
|
129
|
+
if generator is None:
|
130
|
+
|
131
|
+
def decorator(fn: Callable[P, Awaitable[R]]) -> Task[P, R]:
|
132
|
+
return Task(fn, retries=retries, adapter_model=adapter_model)
|
133
|
+
|
134
|
+
return decorator
|
135
|
+
|
136
|
+
return Task(generator, retries=retries, adapter_model=adapter_model)
|
@@ -0,0 +1,92 @@
|
|
1
|
+
from enum import StrEnum
|
2
|
+
from typing import Literal, NotRequired, TypedDict
|
3
|
+
|
4
|
+
from pydantic import BaseModel
|
5
|
+
|
6
|
+
|
7
|
+
class GeminiModel(StrEnum):
|
8
|
+
PRO = "vertex_ai/gemini-1.5-pro"
|
9
|
+
FLASH = "vertex_ai/gemini-1.5-flash"
|
10
|
+
FLASH_8B = "vertex_ai/gemini-1.5-flash-8b"
|
11
|
+
|
12
|
+
|
13
|
+
class UserMediaContentType(StrEnum):
|
14
|
+
# images
|
15
|
+
JPEG = "image/jpeg"
|
16
|
+
PNG = "image/png"
|
17
|
+
WEBP = "image/webp"
|
18
|
+
|
19
|
+
# audio
|
20
|
+
MP3 = "audio/mp3"
|
21
|
+
WAV = "audio/wav"
|
22
|
+
|
23
|
+
# files
|
24
|
+
PDF = "application/pdf"
|
25
|
+
|
26
|
+
|
27
|
+
class LLMTextMessagePart(TypedDict):
|
28
|
+
type: Literal["text"]
|
29
|
+
text: str
|
30
|
+
|
31
|
+
|
32
|
+
class LLMMediaMessagePart(TypedDict):
|
33
|
+
type: Literal["image_url"]
|
34
|
+
image_url: str
|
35
|
+
|
36
|
+
|
37
|
+
class CacheControl(TypedDict):
|
38
|
+
type: Literal["ephemeral"]
|
39
|
+
|
40
|
+
|
41
|
+
class LLMMessage(TypedDict):
|
42
|
+
role: Literal["user", "assistant", "system"]
|
43
|
+
content: list[LLMTextMessagePart | LLMMediaMessagePart]
|
44
|
+
cache_control: NotRequired[CacheControl]
|
45
|
+
|
46
|
+
|
47
|
+
class TextMessagePart(BaseModel):
|
48
|
+
text: str
|
49
|
+
|
50
|
+
def render(self) -> LLMTextMessagePart:
|
51
|
+
return {"type": "text", "text": self.text}
|
52
|
+
|
53
|
+
|
54
|
+
class MediaMessagePart(BaseModel):
|
55
|
+
content_type: UserMediaContentType
|
56
|
+
content: str
|
57
|
+
|
58
|
+
def render(self) -> LLMMediaMessagePart:
|
59
|
+
return {
|
60
|
+
"type": "image_url",
|
61
|
+
"image_url": f"data:{self.content_type};base64,{self.content}",
|
62
|
+
}
|
63
|
+
|
64
|
+
|
65
|
+
class UserMessage(BaseModel):
|
66
|
+
parts: list[TextMessagePart | MediaMessagePart]
|
67
|
+
|
68
|
+
def render(self) -> LLMMessage:
|
69
|
+
content: LLMMessage = {
|
70
|
+
"role": "user",
|
71
|
+
"content": [part.render() for part in self.parts],
|
72
|
+
}
|
73
|
+
if any(isinstance(part, MediaMessagePart) for part in self.parts):
|
74
|
+
content["cache_control"] = {"type": "ephemeral"}
|
75
|
+
return content
|
76
|
+
|
77
|
+
|
78
|
+
class AssistantMessage(BaseModel):
|
79
|
+
text: str
|
80
|
+
|
81
|
+
def render(self) -> LLMMessage:
|
82
|
+
return {"role": "assistant", "content": [{"type": "text", "text": self.text}]}
|
83
|
+
|
84
|
+
|
85
|
+
class SystemMessage(BaseModel):
|
86
|
+
parts: list[TextMessagePart | MediaMessagePart]
|
87
|
+
|
88
|
+
def render(self) -> LLMMessage:
|
89
|
+
return {
|
90
|
+
"role": "system",
|
91
|
+
"content": [part.render() for part in self.parts],
|
92
|
+
}
|
@@ -1,6 +1,6 @@
|
|
1
1
|
[tool.poetry]
|
2
2
|
name = "goose-py"
|
3
|
-
version = "0.
|
3
|
+
version = "0.6.0"
|
4
4
|
description = "A tool for AI workflows based on human-computer collaboration and structured output."
|
5
5
|
authors = [
|
6
6
|
"Nash Taylor <nash@chelle.ai>",
|
@@ -53,6 +53,10 @@ reportImportCycles = false
|
|
53
53
|
reportUnknownMemberType = false
|
54
54
|
reportUnknownVariableType = false
|
55
55
|
stubPath = ".stubs"
|
56
|
+
exclude = [
|
57
|
+
"goose/__init__.py",
|
58
|
+
".venv",
|
59
|
+
]
|
56
60
|
|
57
61
|
[tool.pytest.ini_options]
|
58
62
|
filterwarnings = [
|
goose_py-0.5.0/goose/flow.py
DELETED
@@ -1,458 +0,0 @@
|
|
1
|
-
import json
|
2
|
-
from contextlib import asynccontextmanager
|
3
|
-
from contextvars import ContextVar
|
4
|
-
from types import CodeType
|
5
|
-
from typing import (
|
6
|
-
Any,
|
7
|
-
AsyncIterator,
|
8
|
-
Awaitable,
|
9
|
-
Callable,
|
10
|
-
NewType,
|
11
|
-
Protocol,
|
12
|
-
Self,
|
13
|
-
overload,
|
14
|
-
)
|
15
|
-
|
16
|
-
from pydantic import BaseModel
|
17
|
-
|
18
|
-
from goose.agent import (
|
19
|
-
Agent,
|
20
|
-
AssistantMessage,
|
21
|
-
GeminiModel,
|
22
|
-
IAgentLogger,
|
23
|
-
LLMMessage,
|
24
|
-
SystemMessage,
|
25
|
-
UserMessage,
|
26
|
-
)
|
27
|
-
from goose.errors import Honk
|
28
|
-
from goose.result import Result, TextResult
|
29
|
-
from goose.store import IFlowRunStore, InMemoryFlowRunStore
|
30
|
-
|
31
|
-
SerializedFlowRun = NewType("SerializedFlowRun", str)
|
32
|
-
|
33
|
-
|
34
|
-
class Conversation[R: Result](BaseModel):
|
35
|
-
user_messages: list[UserMessage]
|
36
|
-
result_messages: list[R]
|
37
|
-
context: SystemMessage | None = None
|
38
|
-
|
39
|
-
@property
|
40
|
-
def awaiting_response(self) -> bool:
|
41
|
-
return len(self.user_messages) == len(self.result_messages)
|
42
|
-
|
43
|
-
def render(self) -> list[LLMMessage]:
|
44
|
-
messages: list[LLMMessage] = []
|
45
|
-
if self.context is not None:
|
46
|
-
messages.append(self.context.render())
|
47
|
-
|
48
|
-
for message_index in range(len(self.user_messages)):
|
49
|
-
messages.append(
|
50
|
-
AssistantMessage(
|
51
|
-
text=self.result_messages[message_index].model_dump_json()
|
52
|
-
).render()
|
53
|
-
)
|
54
|
-
messages.append(self.user_messages[message_index].render())
|
55
|
-
|
56
|
-
if len(self.result_messages) > len(self.user_messages):
|
57
|
-
messages.append(
|
58
|
-
AssistantMessage(
|
59
|
-
text=self.result_messages[-1].model_dump_json()
|
60
|
-
).render()
|
61
|
-
)
|
62
|
-
|
63
|
-
return messages
|
64
|
-
|
65
|
-
|
66
|
-
class IAdapter[ResultT: Result](Protocol):
|
67
|
-
__code__: CodeType
|
68
|
-
|
69
|
-
async def __call__(
|
70
|
-
self, *, conversation: Conversation[ResultT], agent: Agent
|
71
|
-
) -> ResultT: ...
|
72
|
-
|
73
|
-
|
74
|
-
class NodeState[ResultT: Result](BaseModel):
|
75
|
-
task_name: str
|
76
|
-
index: int
|
77
|
-
conversation: Conversation[ResultT]
|
78
|
-
last_hash: int
|
79
|
-
|
80
|
-
@property
|
81
|
-
def result(self) -> ResultT:
|
82
|
-
if len(self.conversation.result_messages) == 0:
|
83
|
-
raise Honk("Node awaiting response, has no result")
|
84
|
-
|
85
|
-
return self.conversation.result_messages[-1]
|
86
|
-
|
87
|
-
def set_context(self, *, context: SystemMessage) -> Self:
|
88
|
-
self.conversation.context = context
|
89
|
-
return self
|
90
|
-
|
91
|
-
def add_result(
|
92
|
-
self,
|
93
|
-
*,
|
94
|
-
result: ResultT,
|
95
|
-
new_hash: int | None = None,
|
96
|
-
overwrite: bool = False,
|
97
|
-
) -> Self:
|
98
|
-
if overwrite and len(self.conversation.result_messages) > 0:
|
99
|
-
self.conversation.result_messages[-1] = result
|
100
|
-
else:
|
101
|
-
self.conversation.result_messages.append(result)
|
102
|
-
if new_hash is not None:
|
103
|
-
self.last_hash = new_hash
|
104
|
-
return self
|
105
|
-
|
106
|
-
def add_user_message(self, *, message: UserMessage) -> Self:
|
107
|
-
self.conversation.user_messages.append(message)
|
108
|
-
return self
|
109
|
-
|
110
|
-
|
111
|
-
class FlowRun:
|
112
|
-
def __init__(self) -> None:
|
113
|
-
self._node_states: dict[tuple[str, int], str] = {}
|
114
|
-
self._last_requested_indices: dict[str, int] = {}
|
115
|
-
self._flow_name = ""
|
116
|
-
self._id = ""
|
117
|
-
self._agent: Agent | None = None
|
118
|
-
self._flow_args: tuple[Any, ...] | None = None
|
119
|
-
self._flow_kwargs: dict[str, Any] | None = None
|
120
|
-
|
121
|
-
@property
|
122
|
-
def flow_name(self) -> str:
|
123
|
-
return self._flow_name
|
124
|
-
|
125
|
-
@property
|
126
|
-
def id(self) -> str:
|
127
|
-
return self._id
|
128
|
-
|
129
|
-
@property
|
130
|
-
def agent(self) -> Agent:
|
131
|
-
if self._agent is None:
|
132
|
-
raise Honk("Agent is only accessible once a run is started")
|
133
|
-
return self._agent
|
134
|
-
|
135
|
-
@property
|
136
|
-
def flow_inputs(self) -> tuple[tuple[Any, ...], dict[str, Any]]:
|
137
|
-
if self._flow_args is None or self._flow_kwargs is None:
|
138
|
-
raise Honk("This Flow run has not been executed before")
|
139
|
-
|
140
|
-
return self._flow_args, self._flow_kwargs
|
141
|
-
|
142
|
-
def get_all[R: Result](self, *, task: "Task[Any, R]") -> list[NodeState[R]]:
|
143
|
-
matching_nodes: list[NodeState[R]] = []
|
144
|
-
for key, node_state in self._node_states.items():
|
145
|
-
if key[0] == task.name:
|
146
|
-
matching_nodes.append(
|
147
|
-
NodeState[task.result_type].model_validate_json(node_state)
|
148
|
-
)
|
149
|
-
return sorted(matching_nodes, key=lambda node: node.index)
|
150
|
-
|
151
|
-
def get[R: Result](self, *, task: "Task[Any, R]", index: int = 0) -> NodeState[R]:
|
152
|
-
if (
|
153
|
-
existing_node_state := self._node_states.get((task.name, index))
|
154
|
-
) is not None:
|
155
|
-
return NodeState[task.result_type].model_validate_json(existing_node_state)
|
156
|
-
else:
|
157
|
-
return NodeState[task.result_type](
|
158
|
-
task_name=task.name,
|
159
|
-
index=index,
|
160
|
-
conversation=Conversation[task.result_type](
|
161
|
-
user_messages=[], result_messages=[]
|
162
|
-
),
|
163
|
-
last_hash=0,
|
164
|
-
)
|
165
|
-
|
166
|
-
def set_flow_inputs(self, *args: Any, **kwargs: Any) -> None:
|
167
|
-
self._flow_args = args
|
168
|
-
self._flow_kwargs = kwargs
|
169
|
-
|
170
|
-
def add_node_state(self, node_state: NodeState[Any], /) -> None:
|
171
|
-
key = (node_state.task_name, node_state.index)
|
172
|
-
self._node_states[key] = node_state.model_dump_json()
|
173
|
-
|
174
|
-
def get_next[R: Result](self, *, task: "Task[Any, R]") -> NodeState[R]:
|
175
|
-
if task.name not in self._last_requested_indices:
|
176
|
-
self._last_requested_indices[task.name] = 0
|
177
|
-
else:
|
178
|
-
self._last_requested_indices[task.name] += 1
|
179
|
-
|
180
|
-
return self.get(task=task, index=self._last_requested_indices[task.name])
|
181
|
-
|
182
|
-
def start(
|
183
|
-
self,
|
184
|
-
*,
|
185
|
-
flow_name: str,
|
186
|
-
run_id: str,
|
187
|
-
agent_logger: IAgentLogger | None = None,
|
188
|
-
) -> None:
|
189
|
-
self._last_requested_indices = {}
|
190
|
-
self._flow_name = flow_name
|
191
|
-
self._id = run_id
|
192
|
-
self._agent = Agent(
|
193
|
-
flow_name=self.flow_name, run_id=self.id, logger=agent_logger
|
194
|
-
)
|
195
|
-
|
196
|
-
def end(self) -> None:
    """Clear the execution state set by ``start``; stored node states are kept."""
    self._last_requested_indices = {}
    self._flow_name = self._id = ""
    self._agent = None
|
201
|
-
|
202
|
-
def clear_node(self, *, task: "Task[Any, Result]", index: int) -> None:
    """Drop the stored state for ``task`` at ``index``; a missing entry is ignored."""
    self._node_states.pop((task.name, index), None)
|
206
|
-
|
207
|
-
def dump(self) -> SerializedFlowRun:
    """Serialize the run's node states and flow inputs into one JSON document.

    ``(task_name, index)`` keys are flattened to ``"task_name:index"``
    strings because JSON object keys must be strings; ``load`` parses them
    back. Positional flow args are written as a JSON list.
    """
    flow_args, flow_kwargs = self.flow_inputs
    flattened_states = {
        f"{task_name}:{index}": serialized
        for (task_name, index), serialized in self._node_states.items()
    }
    document = {
        "node_states": flattened_states,
        "flow_args": list(flow_args),
        "flow_kwargs": flow_kwargs,
    }
    return SerializedFlowRun(json.dumps(document))
|
222
|
-
|
223
|
-
@classmethod
def load(cls, serialized_flow_run: SerializedFlowRun, /) -> Self:
    """Rebuild a run from the JSON document produced by ``dump``.

    Only node states and flow inputs are read from the document.

    Raises:
        ValueError: if a node-state key is not of the form ``"name:index"``.
    """
    flow_run = cls()
    run = json.loads(serialized_flow_run)

    node_states: dict[tuple[str, int], str] = {}
    for key, node_state in run["node_states"].items():
        # ``dump`` writes "task_name:index" and the index is always an int,
        # so split on the LAST colon; a task name containing ":" then still
        # round-trips instead of failing to unpack.
        task_name, index = key.rsplit(":", 1)
        node_states[(task_name, int(index))] = node_state
    flow_run._node_states = node_states

    flow_run._flow_args = tuple(run["flow_args"])
    flow_run._flow_kwargs = run["flow_kwargs"]

    return flow_run
|
238
|
-
|
239
|
-
|
240
|
-
# The FlowRun active in the current context (None outside a run). Flow.start_run
# binds it; Flow and Task methods read it to find the run they operate on.
_current_flow_run: ContextVar[FlowRun | None] = ContextVar("current_flow_run", default=None)
|
243
|
-
|
244
|
-
|
245
|
-
class Flow[**P]:
|
246
|
-
def __init__(
|
247
|
-
self,
|
248
|
-
fn: Callable[P, Awaitable[None]],
|
249
|
-
/,
|
250
|
-
*,
|
251
|
-
name: str | None = None,
|
252
|
-
store: IFlowRunStore | None = None,
|
253
|
-
agent_logger: IAgentLogger | None = None,
|
254
|
-
) -> None:
|
255
|
-
self._fn = fn
|
256
|
-
self._name = name
|
257
|
-
self._agent_logger = agent_logger
|
258
|
-
self._store = store or InMemoryFlowRunStore(flow_name=self.name)
|
259
|
-
|
260
|
-
@property
|
261
|
-
def name(self) -> str:
|
262
|
-
return self._name or self._fn.__name__
|
263
|
-
|
264
|
-
@property
|
265
|
-
def current_run(self) -> FlowRun:
|
266
|
-
run = _current_flow_run.get()
|
267
|
-
if run is None:
|
268
|
-
raise Honk("No current flow run")
|
269
|
-
return run
|
270
|
-
|
271
|
-
@asynccontextmanager
|
272
|
-
async def start_run(self, *, run_id: str) -> AsyncIterator[FlowRun]:
|
273
|
-
existing_run = await self._store.get(run_id=run_id)
|
274
|
-
if existing_run is None:
|
275
|
-
run = FlowRun()
|
276
|
-
else:
|
277
|
-
run = existing_run
|
278
|
-
|
279
|
-
old_run = _current_flow_run.get()
|
280
|
-
_current_flow_run.set(run)
|
281
|
-
|
282
|
-
run.start(flow_name=self.name, run_id=run_id, agent_logger=self._agent_logger)
|
283
|
-
yield run
|
284
|
-
await self._store.save(run=run)
|
285
|
-
run.end()
|
286
|
-
|
287
|
-
_current_flow_run.set(old_run)
|
288
|
-
|
289
|
-
async def generate(self, *args: P.args, **kwargs: P.kwargs) -> None:
|
290
|
-
flow_run = _current_flow_run.get()
|
291
|
-
if flow_run is None:
|
292
|
-
raise Honk("No current flow run")
|
293
|
-
|
294
|
-
flow_run.set_flow_inputs(*args, **kwargs)
|
295
|
-
await self._fn(*args, **kwargs)
|
296
|
-
|
297
|
-
async def regenerate(self) -> None:
|
298
|
-
flow_run = _current_flow_run.get()
|
299
|
-
if flow_run is None:
|
300
|
-
raise Honk("No current flow run")
|
301
|
-
|
302
|
-
flow_args, flow_kwargs = flow_run.flow_inputs
|
303
|
-
await self._fn(*flow_args, **flow_kwargs)
|
304
|
-
|
305
|
-
|
306
|
-
class Task[**P, R: Result]:
|
307
|
-
def __init__(
|
308
|
-
self,
|
309
|
-
generator: Callable[P, Awaitable[R]],
|
310
|
-
/,
|
311
|
-
*,
|
312
|
-
retries: int = 0,
|
313
|
-
adapter_model: GeminiModel = GeminiModel.FLASH,
|
314
|
-
) -> None:
|
315
|
-
self._generator = generator
|
316
|
-
self._retries = retries
|
317
|
-
self._adapter_model = adapter_model
|
318
|
-
self._adapter_model = adapter_model
|
319
|
-
|
320
|
-
@property
|
321
|
-
def result_type(self) -> type[R]:
|
322
|
-
result_type = self._generator.__annotations__.get("return")
|
323
|
-
if result_type is None:
|
324
|
-
raise Honk(f"Task {self.name} has no return type annotation")
|
325
|
-
return result_type
|
326
|
-
|
327
|
-
@property
|
328
|
-
def name(self) -> str:
|
329
|
-
return self._generator.__name__
|
330
|
-
|
331
|
-
async def generate(
|
332
|
-
self, state: NodeState[R], *args: P.args, **kwargs: P.kwargs
|
333
|
-
) -> R:
|
334
|
-
state_hash = self.__hash_task_call(*args, **kwargs)
|
335
|
-
if state_hash != state.last_hash:
|
336
|
-
result = await self._generator(*args, **kwargs)
|
337
|
-
state.add_result(result=result, new_hash=state_hash, overwrite=True)
|
338
|
-
return result
|
339
|
-
else:
|
340
|
-
return state.result
|
341
|
-
|
342
|
-
async def jam(
|
343
|
-
self,
|
344
|
-
*,
|
345
|
-
user_message: UserMessage,
|
346
|
-
context: SystemMessage | None = None,
|
347
|
-
index: int = 0,
|
348
|
-
) -> R:
|
349
|
-
flow_run = self.__get_current_flow_run()
|
350
|
-
node_state = flow_run.get(task=self, index=index)
|
351
|
-
|
352
|
-
if context is not None:
|
353
|
-
node_state.set_context(context=context)
|
354
|
-
node_state.add_user_message(message=user_message)
|
355
|
-
|
356
|
-
result = await self.__adapt(
|
357
|
-
conversation=node_state.conversation, agent=flow_run.agent
|
358
|
-
)
|
359
|
-
node_state.add_result(result=result)
|
360
|
-
flow_run.add_node_state(node_state)
|
361
|
-
|
362
|
-
return result
|
363
|
-
|
364
|
-
async def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R:
|
365
|
-
flow_run = self.__get_current_flow_run()
|
366
|
-
node_state = flow_run.get_next(task=self)
|
367
|
-
result = await self.generate(node_state, *args, **kwargs)
|
368
|
-
flow_run.add_node_state(node_state)
|
369
|
-
return result
|
370
|
-
|
371
|
-
async def __adapt(self, *, conversation: Conversation[R], agent: Agent) -> R:
|
372
|
-
messages: list[UserMessage | AssistantMessage] = []
|
373
|
-
for message_index in range(len(conversation.user_messages)):
|
374
|
-
user_message = conversation.user_messages[message_index]
|
375
|
-
result = conversation.result_messages[message_index]
|
376
|
-
|
377
|
-
if isinstance(result, TextResult):
|
378
|
-
assistant_text = result.text
|
379
|
-
else:
|
380
|
-
assistant_text = result.model_dump_json()
|
381
|
-
assistant_message = AssistantMessage(text=assistant_text)
|
382
|
-
messages.append(assistant_message)
|
383
|
-
messages.append(user_message)
|
384
|
-
|
385
|
-
return await agent(
|
386
|
-
messages=messages,
|
387
|
-
model=self._adapter_model,
|
388
|
-
task_name=f"adapt--{self.name}",
|
389
|
-
system=conversation.context,
|
390
|
-
response_model=self.result_type,
|
391
|
-
)
|
392
|
-
|
393
|
-
def __hash_task_call(self, *args: P.args, **kwargs: P.kwargs) -> int:
|
394
|
-
try:
|
395
|
-
to_hash = str(
|
396
|
-
tuple(args)
|
397
|
-
+ tuple(kwargs.values())
|
398
|
-
+ (self._generator.__code__, self._adapter_model)
|
399
|
-
)
|
400
|
-
return hash(to_hash)
|
401
|
-
except TypeError:
|
402
|
-
raise Honk(f"Unhashable argument to task {self.name}: {args} {kwargs}")
|
403
|
-
|
404
|
-
def __get_current_flow_run(self) -> FlowRun:
|
405
|
-
run = _current_flow_run.get()
|
406
|
-
if run is None:
|
407
|
-
raise Honk("No current flow run")
|
408
|
-
return run
|
409
|
-
|
410
|
-
|
411
|
-
@overload
|
412
|
-
def task[**P, R: Result](generator: Callable[P, Awaitable[R]], /) -> Task[P, R]: ...
|
413
|
-
@overload
|
414
|
-
def task[**P, R: Result](
|
415
|
-
*, retries: int = 0, adapter_model: GeminiModel = GeminiModel.FLASH
|
416
|
-
) -> Callable[[Callable[P, Awaitable[R]]], Task[P, R]]: ...
|
417
|
-
def task[**P, R: Result](
|
418
|
-
generator: Callable[P, Awaitable[R]] | None = None,
|
419
|
-
/,
|
420
|
-
*,
|
421
|
-
retries: int = 0,
|
422
|
-
adapter_model: GeminiModel = GeminiModel.FLASH,
|
423
|
-
) -> Task[P, R] | Callable[[Callable[P, Awaitable[R]]], Task[P, R]]:
|
424
|
-
if generator is None:
|
425
|
-
|
426
|
-
def decorator(fn: Callable[P, Awaitable[R]]) -> Task[P, R]:
|
427
|
-
return Task(fn, retries=retries, adapter_model=adapter_model)
|
428
|
-
|
429
|
-
return decorator
|
430
|
-
|
431
|
-
return Task(generator, retries=retries, adapter_model=adapter_model)
|
432
|
-
|
433
|
-
|
434
|
-
@overload
|
435
|
-
def flow[**P](fn: Callable[P, Awaitable[None]], /) -> Flow[P]: ...
|
436
|
-
@overload
|
437
|
-
def flow[**P](
|
438
|
-
*,
|
439
|
-
name: str | None = None,
|
440
|
-
store: IFlowRunStore | None = None,
|
441
|
-
agent_logger: IAgentLogger | None = None,
|
442
|
-
) -> Callable[[Callable[P, Awaitable[None]]], Flow[P]]: ...
|
443
|
-
def flow[**P](
|
444
|
-
fn: Callable[P, Awaitable[None]] | None = None,
|
445
|
-
/,
|
446
|
-
*,
|
447
|
-
name: str | None = None,
|
448
|
-
store: IFlowRunStore | None = None,
|
449
|
-
agent_logger: IAgentLogger | None = None,
|
450
|
-
) -> Flow[P] | Callable[[Callable[P, Awaitable[None]]], Flow[P]]:
|
451
|
-
if fn is None:
|
452
|
-
|
453
|
-
def decorator(fn: Callable[P, Awaitable[None]]) -> Flow[P]:
|
454
|
-
return Flow(fn, name=name, store=store, agent_logger=agent_logger)
|
455
|
-
|
456
|
-
return decorator
|
457
|
-
|
458
|
-
return Flow(fn, name=name, store=store, agent_logger=agent_logger)
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|