goose-py 0.5.1__py3-none-any.whl → 0.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- goose/__init__.py +4 -0
- goose/_internal/agent.py +201 -0
- goose/_internal/conversation.py +41 -0
- goose/_internal/flow.py +106 -0
- goose/_internal/state.py +190 -0
- goose/{store.py → _internal/store.py} +8 -6
- goose/_internal/task.py +136 -0
- goose/_internal/types/__init__.py +0 -0
- goose/_internal/types/agent.py +92 -0
- goose/agent.py +28 -283
- goose/flow.py +2 -457
- goose/runs.py +4 -0
- goose_py-0.7.0.dist-info/METADATA +14 -0
- goose_py-0.7.0.dist-info/RECORD +18 -0
- {goose_py-0.5.1.dist-info → goose_py-0.7.0.dist-info}/WHEEL +1 -1
- goose_py-0.5.1.dist-info/METADATA +0 -31
- goose_py-0.5.1.dist-info/RECORD +0 -10
- goose/{result.py → _internal/result.py} +0 -0
goose/__init__.py
CHANGED
goose/_internal/agent.py
ADDED
@@ -0,0 +1,201 @@
|
|
1
|
+
import json
|
2
|
+
import logging
|
3
|
+
from datetime import datetime
|
4
|
+
from typing import Any, ClassVar, Protocol, TypedDict
|
5
|
+
|
6
|
+
from litellm import acompletion
|
7
|
+
from pydantic import BaseModel, computed_field
|
8
|
+
|
9
|
+
from goose._internal.result import Result, TextResult
|
10
|
+
from goose._internal.types.agent import (
|
11
|
+
AssistantMessage,
|
12
|
+
GeminiModel,
|
13
|
+
SystemMessage,
|
14
|
+
UserMessage,
|
15
|
+
)
|
16
|
+
|
17
|
+
|
18
|
+
class AgentResponseDump(TypedDict):
    """Flat, log-friendly snapshot of one agent call.

    This is the return shape of ``AgentResponse.minimized_dump``. The system
    and input messages are JSON-encoded strings. ``output_message`` is either
    JSON (structured results) or the raw response text.
    """

    # identity of the call
    run_id: str
    flow_name: str
    task_name: str
    model: str

    # message payloads
    system_message: str
    input_messages: list[str]
    output_message: str

    # token usage
    input_tokens: int
    output_tokens: int

    # cost, in the units of AgentResponse's rate tables (cents)
    input_cost: float
    output_cost: float
    total_cost: float

    # timing
    start_time: datetime
    end_time: datetime
    duration_ms: int
|
34
|
+
|
35
|
+
|
36
|
+
class AgentResponse[R: BaseModel | str](BaseModel):
|
37
|
+
INPUT_CENTS_PER_MILLION_TOKENS: ClassVar[dict[GeminiModel, float]] = {
|
38
|
+
GeminiModel.FLASH_8B: 30,
|
39
|
+
GeminiModel.FLASH: 15,
|
40
|
+
GeminiModel.PRO: 500,
|
41
|
+
}
|
42
|
+
OUTPUT_CENTS_PER_MILLION_TOKENS: ClassVar[dict[GeminiModel, float]] = {
|
43
|
+
GeminiModel.FLASH_8B: 30,
|
44
|
+
GeminiModel.FLASH: 15,
|
45
|
+
GeminiModel.PRO: 500,
|
46
|
+
}
|
47
|
+
|
48
|
+
response: R
|
49
|
+
run_id: str
|
50
|
+
flow_name: str
|
51
|
+
task_name: str
|
52
|
+
model: GeminiModel
|
53
|
+
system: SystemMessage | None = None
|
54
|
+
input_messages: list[UserMessage | AssistantMessage]
|
55
|
+
input_tokens: int
|
56
|
+
output_tokens: int
|
57
|
+
start_time: datetime
|
58
|
+
end_time: datetime
|
59
|
+
|
60
|
+
@computed_field
|
61
|
+
@property
|
62
|
+
def duration_ms(self) -> int:
|
63
|
+
return int((self.end_time - self.start_time).total_seconds() * 1000)
|
64
|
+
|
65
|
+
@computed_field
|
66
|
+
@property
|
67
|
+
def input_cost(self) -> float:
|
68
|
+
return (
|
69
|
+
self.INPUT_CENTS_PER_MILLION_TOKENS[self.model]
|
70
|
+
* self.input_tokens
|
71
|
+
/ 1_000_000
|
72
|
+
)
|
73
|
+
|
74
|
+
@computed_field
|
75
|
+
@property
|
76
|
+
def output_cost(self) -> float:
|
77
|
+
return (
|
78
|
+
self.OUTPUT_CENTS_PER_MILLION_TOKENS[self.model]
|
79
|
+
* self.output_tokens
|
80
|
+
/ 1_000_000
|
81
|
+
)
|
82
|
+
|
83
|
+
@computed_field
|
84
|
+
@property
|
85
|
+
def total_cost(self) -> float:
|
86
|
+
return self.input_cost + self.output_cost
|
87
|
+
|
88
|
+
def minimized_dump(self) -> AgentResponseDump:
|
89
|
+
if self.system is None:
|
90
|
+
minimized_system_message = ""
|
91
|
+
else:
|
92
|
+
minimized_system_message = self.system.render()
|
93
|
+
for part in minimized_system_message["content"]:
|
94
|
+
if part["type"] == "image_url":
|
95
|
+
part["image_url"] = "__MEDIA__"
|
96
|
+
minimized_system_message = json.dumps(minimized_system_message)
|
97
|
+
|
98
|
+
minimized_input_messages = [message.render() for message in self.input_messages]
|
99
|
+
for message in minimized_input_messages:
|
100
|
+
for part in message["content"]:
|
101
|
+
if part["type"] == "image_url":
|
102
|
+
part["image_url"] = "__MEDIA__"
|
103
|
+
minimized_input_messages = [
|
104
|
+
json.dumps(message) for message in minimized_input_messages
|
105
|
+
]
|
106
|
+
|
107
|
+
output_message = (
|
108
|
+
self.response.model_dump_json()
|
109
|
+
if isinstance(self.response, BaseModel)
|
110
|
+
else self.response
|
111
|
+
)
|
112
|
+
|
113
|
+
return {
|
114
|
+
"run_id": self.run_id,
|
115
|
+
"flow_name": self.flow_name,
|
116
|
+
"task_name": self.task_name,
|
117
|
+
"model": self.model.value,
|
118
|
+
"system_message": minimized_system_message,
|
119
|
+
"input_messages": minimized_input_messages,
|
120
|
+
"output_message": output_message,
|
121
|
+
"input_tokens": self.input_tokens,
|
122
|
+
"output_tokens": self.output_tokens,
|
123
|
+
"input_cost": self.input_cost,
|
124
|
+
"output_cost": self.output_cost,
|
125
|
+
"total_cost": self.total_cost,
|
126
|
+
"start_time": self.start_time,
|
127
|
+
"end_time": self.end_time,
|
128
|
+
"duration_ms": self.duration_ms,
|
129
|
+
}
|
130
|
+
|
131
|
+
|
132
|
+
class IAgentLogger(Protocol):
    """Structural type for async sinks that receive one ``AgentResponse`` per call.

    ``Agent`` awaits it with the keyword argument ``response`` after every
    completion.
    """

    async def __call__(self, *, response: AgentResponse[Any]) -> None:
        """Consume the record of a finished agent call."""
        ...
|
134
|
+
|
135
|
+
|
136
|
+
class Agent:
|
137
|
+
def __init__(
|
138
|
+
self,
|
139
|
+
*,
|
140
|
+
flow_name: str,
|
141
|
+
run_id: str,
|
142
|
+
logger: IAgentLogger | None = None,
|
143
|
+
) -> None:
|
144
|
+
self.flow_name = flow_name
|
145
|
+
self.run_id = run_id
|
146
|
+
self.logger = logger
|
147
|
+
|
148
|
+
async def __call__[R: Result](
|
149
|
+
self,
|
150
|
+
*,
|
151
|
+
messages: list[UserMessage | AssistantMessage],
|
152
|
+
model: GeminiModel,
|
153
|
+
task_name: str,
|
154
|
+
response_model: type[R] = TextResult,
|
155
|
+
system: SystemMessage | None = None,
|
156
|
+
) -> R:
|
157
|
+
start_time = datetime.now()
|
158
|
+
rendered_messages = [message.render() for message in messages]
|
159
|
+
if system is not None:
|
160
|
+
rendered_messages.insert(0, system.render())
|
161
|
+
|
162
|
+
if response_model is TextResult:
|
163
|
+
response = await acompletion(model=model.value, messages=rendered_messages)
|
164
|
+
parsed_response = response_model.model_validate(
|
165
|
+
{"text": response.choices[0].message.content}
|
166
|
+
)
|
167
|
+
else:
|
168
|
+
response = await acompletion(
|
169
|
+
model=model.value,
|
170
|
+
messages=rendered_messages,
|
171
|
+
response_format={
|
172
|
+
"type": "json_object",
|
173
|
+
"response_schema": response_model.model_json_schema(),
|
174
|
+
"enforce_validation": True,
|
175
|
+
},
|
176
|
+
)
|
177
|
+
parsed_response = response_model.model_validate_json(
|
178
|
+
response.choices[0].message.content
|
179
|
+
)
|
180
|
+
|
181
|
+
end_time = datetime.now()
|
182
|
+
agent_response = AgentResponse(
|
183
|
+
response=parsed_response,
|
184
|
+
run_id=self.run_id,
|
185
|
+
flow_name=self.flow_name,
|
186
|
+
task_name=task_name,
|
187
|
+
model=model,
|
188
|
+
system=system,
|
189
|
+
input_messages=messages,
|
190
|
+
input_tokens=response.usage.prompt_tokens,
|
191
|
+
output_tokens=response.usage.completion_tokens,
|
192
|
+
start_time=start_time,
|
193
|
+
end_time=end_time,
|
194
|
+
)
|
195
|
+
|
196
|
+
if self.logger is not None:
|
197
|
+
await self.logger(response=agent_response)
|
198
|
+
else:
|
199
|
+
logging.info(agent_response.model_dump())
|
200
|
+
|
201
|
+
return parsed_response
|
goose/_internal/conversation.py
ADDED
@@ -0,0 +1,41 @@
|
|
1
|
+
from pydantic import BaseModel
|
2
|
+
|
3
|
+
from goose._internal.result import Result
|
4
|
+
from goose._internal.types.agent import (
|
5
|
+
AssistantMessage,
|
6
|
+
LLMMessage,
|
7
|
+
SystemMessage,
|
8
|
+
UserMessage,
|
9
|
+
)
|
10
|
+
|
11
|
+
|
12
|
+
class Conversation[R: Result](BaseModel):
|
13
|
+
user_messages: list[UserMessage]
|
14
|
+
result_messages: list[R]
|
15
|
+
context: SystemMessage | None = None
|
16
|
+
|
17
|
+
@property
|
18
|
+
def awaiting_response(self) -> bool:
|
19
|
+
return len(self.user_messages) == len(self.result_messages)
|
20
|
+
|
21
|
+
def render(self) -> list[LLMMessage]:
|
22
|
+
messages: list[LLMMessage] = []
|
23
|
+
if self.context is not None:
|
24
|
+
messages.append(self.context.render())
|
25
|
+
|
26
|
+
for message_index in range(len(self.user_messages)):
|
27
|
+
messages.append(
|
28
|
+
AssistantMessage(
|
29
|
+
text=self.result_messages[message_index].model_dump_json()
|
30
|
+
).render()
|
31
|
+
)
|
32
|
+
messages.append(self.user_messages[message_index].render())
|
33
|
+
|
34
|
+
if len(self.result_messages) > len(self.user_messages):
|
35
|
+
messages.append(
|
36
|
+
AssistantMessage(
|
37
|
+
text=self.result_messages[-1].model_dump_json()
|
38
|
+
).render()
|
39
|
+
)
|
40
|
+
|
41
|
+
return messages
|
goose/_internal/flow.py
ADDED
@@ -0,0 +1,106 @@
|
|
1
|
+
from contextlib import asynccontextmanager
|
2
|
+
from types import CodeType
|
3
|
+
from typing import AsyncIterator, Awaitable, Callable, Protocol, overload
|
4
|
+
|
5
|
+
from goose._internal.agent import Agent, IAgentLogger
|
6
|
+
from goose._internal.conversation import Conversation
|
7
|
+
from goose._internal.result import Result
|
8
|
+
from goose._internal.state import FlowRun, get_current_flow_run, set_current_flow_run
|
9
|
+
from goose._internal.store import IFlowRunStore, InMemoryFlowRunStore
|
10
|
+
from goose.errors import Honk
|
11
|
+
|
12
|
+
|
13
|
+
class IAdapter[ResultT: Result](Protocol):
|
14
|
+
__code__: CodeType
|
15
|
+
|
16
|
+
async def __call__(
|
17
|
+
self, *, conversation: Conversation[ResultT], agent: Agent
|
18
|
+
) -> ResultT: ...
|
19
|
+
|
20
|
+
|
21
|
+
class Flow[**P]:
|
22
|
+
def __init__(
|
23
|
+
self,
|
24
|
+
fn: Callable[P, Awaitable[None]],
|
25
|
+
/,
|
26
|
+
*,
|
27
|
+
name: str | None = None,
|
28
|
+
store: IFlowRunStore | None = None,
|
29
|
+
agent_logger: IAgentLogger | None = None,
|
30
|
+
) -> None:
|
31
|
+
self._fn = fn
|
32
|
+
self._name = name
|
33
|
+
self._agent_logger = agent_logger
|
34
|
+
self._store = store or InMemoryFlowRunStore(flow_name=self.name)
|
35
|
+
|
36
|
+
@property
|
37
|
+
def name(self) -> str:
|
38
|
+
return self._name or self._fn.__name__
|
39
|
+
|
40
|
+
@property
|
41
|
+
def current_run(self) -> FlowRun:
|
42
|
+
run = get_current_flow_run()
|
43
|
+
if run is None:
|
44
|
+
raise Honk("No current flow run")
|
45
|
+
return run
|
46
|
+
|
47
|
+
@asynccontextmanager
|
48
|
+
async def start_run(self, *, run_id: str) -> AsyncIterator[FlowRun]:
|
49
|
+
existing_run = await self._store.get(run_id=run_id)
|
50
|
+
if existing_run is None:
|
51
|
+
run = FlowRun()
|
52
|
+
else:
|
53
|
+
run = existing_run
|
54
|
+
|
55
|
+
old_run = get_current_flow_run()
|
56
|
+
set_current_flow_run(run)
|
57
|
+
|
58
|
+
run.start(flow_name=self.name, run_id=run_id, agent_logger=self._agent_logger)
|
59
|
+
yield run
|
60
|
+
await self._store.save(run=run)
|
61
|
+
run.end()
|
62
|
+
|
63
|
+
set_current_flow_run(old_run)
|
64
|
+
|
65
|
+
async def generate(self, *args: P.args, **kwargs: P.kwargs) -> None:
|
66
|
+
flow_run = get_current_flow_run()
|
67
|
+
if flow_run is None:
|
68
|
+
raise Honk("No current flow run")
|
69
|
+
|
70
|
+
flow_run.set_flow_inputs(*args, **kwargs)
|
71
|
+
await self._fn(*args, **kwargs)
|
72
|
+
|
73
|
+
async def regenerate(self) -> None:
|
74
|
+
flow_run = get_current_flow_run()
|
75
|
+
if flow_run is None:
|
76
|
+
raise Honk("No current flow run")
|
77
|
+
|
78
|
+
flow_args, flow_kwargs = flow_run.flow_inputs
|
79
|
+
await self._fn(*flow_args, **flow_kwargs)
|
80
|
+
|
81
|
+
|
82
|
+
@overload
|
83
|
+
def flow[**P](fn: Callable[P, Awaitable[None]], /) -> Flow[P]: ...
|
84
|
+
@overload
|
85
|
+
def flow[**P](
|
86
|
+
*,
|
87
|
+
name: str | None = None,
|
88
|
+
store: IFlowRunStore | None = None,
|
89
|
+
agent_logger: IAgentLogger | None = None,
|
90
|
+
) -> Callable[[Callable[P, Awaitable[None]]], Flow[P]]: ...
|
91
|
+
def flow[**P](
|
92
|
+
fn: Callable[P, Awaitable[None]] | None = None,
|
93
|
+
/,
|
94
|
+
*,
|
95
|
+
name: str | None = None,
|
96
|
+
store: IFlowRunStore | None = None,
|
97
|
+
agent_logger: IAgentLogger | None = None,
|
98
|
+
) -> Flow[P] | Callable[[Callable[P, Awaitable[None]]], Flow[P]]:
|
99
|
+
if fn is None:
|
100
|
+
|
101
|
+
def decorator(fn: Callable[P, Awaitable[None]]) -> Flow[P]:
|
102
|
+
return Flow(fn, name=name, store=store, agent_logger=agent_logger)
|
103
|
+
|
104
|
+
return decorator
|
105
|
+
|
106
|
+
return Flow(fn, name=name, store=store, agent_logger=agent_logger)
|
goose/_internal/state.py
ADDED
@@ -0,0 +1,190 @@
|
|
1
|
+
from contextvars import ContextVar
|
2
|
+
from dataclasses import dataclass
|
3
|
+
from typing import TYPE_CHECKING, Any, Self
|
4
|
+
|
5
|
+
from pydantic import BaseModel
|
6
|
+
|
7
|
+
from goose._internal.agent import (
|
8
|
+
Agent,
|
9
|
+
IAgentLogger,
|
10
|
+
SystemMessage,
|
11
|
+
UserMessage,
|
12
|
+
)
|
13
|
+
from goose._internal.conversation import Conversation
|
14
|
+
from goose._internal.result import Result
|
15
|
+
from goose.errors import Honk
|
16
|
+
|
17
|
+
if TYPE_CHECKING:
|
18
|
+
from goose._internal.task import Task
|
19
|
+
|
20
|
+
|
21
|
+
@dataclass
class FlowRunState:
    """Plain snapshot of a ``FlowRun``, as produced by ``FlowRun.dump``.

    ``FlowRun.load`` turns it back into a run.
    """

    # (task name, call index) -> NodeState serialised as JSON
    node_states: dict[tuple[str, int], str]
    # positional inputs recorded via FlowRun.set_flow_inputs
    flow_args: tuple[Any, ...]
    # keyword inputs recorded via FlowRun.set_flow_inputs
    flow_kwargs: dict[str, Any]
|
26
|
+
|
27
|
+
|
28
|
+
class NodeState[ResultT: Result](BaseModel):
|
29
|
+
task_name: str
|
30
|
+
index: int
|
31
|
+
conversation: Conversation[ResultT]
|
32
|
+
last_hash: int
|
33
|
+
|
34
|
+
@property
|
35
|
+
def result(self) -> ResultT:
|
36
|
+
if len(self.conversation.result_messages) == 0:
|
37
|
+
raise Honk("Node awaiting response, has no result")
|
38
|
+
|
39
|
+
return self.conversation.result_messages[-1]
|
40
|
+
|
41
|
+
def set_context(self, *, context: SystemMessage) -> Self:
|
42
|
+
self.conversation.context = context
|
43
|
+
return self
|
44
|
+
|
45
|
+
def add_result(
|
46
|
+
self,
|
47
|
+
*,
|
48
|
+
result: ResultT,
|
49
|
+
new_hash: int | None = None,
|
50
|
+
overwrite: bool = False,
|
51
|
+
) -> Self:
|
52
|
+
if overwrite and len(self.conversation.result_messages) > 0:
|
53
|
+
self.conversation.result_messages[-1] = result
|
54
|
+
else:
|
55
|
+
self.conversation.result_messages.append(result)
|
56
|
+
if new_hash is not None:
|
57
|
+
self.last_hash = new_hash
|
58
|
+
return self
|
59
|
+
|
60
|
+
def add_user_message(self, *, message: UserMessage) -> Self:
|
61
|
+
self.conversation.user_messages.append(message)
|
62
|
+
return self
|
63
|
+
|
64
|
+
|
65
|
+
class FlowRun:
|
66
|
+
def __init__(self) -> None:
|
67
|
+
self._node_states: dict[tuple[str, int], str] = {}
|
68
|
+
self._last_requested_indices: dict[str, int] = {}
|
69
|
+
self._flow_name = ""
|
70
|
+
self._id = ""
|
71
|
+
self._agent: Agent | None = None
|
72
|
+
self._flow_args: tuple[Any, ...] | None = None
|
73
|
+
self._flow_kwargs: dict[str, Any] | None = None
|
74
|
+
|
75
|
+
@property
|
76
|
+
def flow_name(self) -> str:
|
77
|
+
return self._flow_name
|
78
|
+
|
79
|
+
@property
|
80
|
+
def id(self) -> str:
|
81
|
+
return self._id
|
82
|
+
|
83
|
+
@property
|
84
|
+
def agent(self) -> Agent:
|
85
|
+
if self._agent is None:
|
86
|
+
raise Honk("Agent is only accessible once a run is started")
|
87
|
+
return self._agent
|
88
|
+
|
89
|
+
@property
|
90
|
+
def flow_inputs(self) -> tuple[tuple[Any, ...], dict[str, Any]]:
|
91
|
+
if self._flow_args is None or self._flow_kwargs is None:
|
92
|
+
raise Honk("This Flow run has not been executed before")
|
93
|
+
|
94
|
+
return self._flow_args, self._flow_kwargs
|
95
|
+
|
96
|
+
def get_all[R: Result](self, *, task: "Task[Any, R]") -> list[NodeState[R]]:
|
97
|
+
matching_nodes: list[NodeState[R]] = []
|
98
|
+
for key, node_state in self._node_states.items():
|
99
|
+
if key[0] == task.name:
|
100
|
+
matching_nodes.append(
|
101
|
+
NodeState[task.result_type].model_validate_json(node_state)
|
102
|
+
)
|
103
|
+
return sorted(matching_nodes, key=lambda node: node.index)
|
104
|
+
|
105
|
+
def get[R: Result](self, *, task: "Task[Any, R]", index: int = 0) -> NodeState[R]:
|
106
|
+
if (
|
107
|
+
existing_node_state := self._node_states.get((task.name, index))
|
108
|
+
) is not None:
|
109
|
+
return NodeState[task.result_type].model_validate_json(existing_node_state)
|
110
|
+
else:
|
111
|
+
return NodeState[task.result_type](
|
112
|
+
task_name=task.name,
|
113
|
+
index=index,
|
114
|
+
conversation=Conversation[task.result_type](
|
115
|
+
user_messages=[], result_messages=[]
|
116
|
+
),
|
117
|
+
last_hash=0,
|
118
|
+
)
|
119
|
+
|
120
|
+
def set_flow_inputs(self, *args: Any, **kwargs: Any) -> None:
|
121
|
+
self._flow_args = args
|
122
|
+
self._flow_kwargs = kwargs
|
123
|
+
|
124
|
+
def add_node_state(self, node_state: NodeState[Any], /) -> None:
|
125
|
+
key = (node_state.task_name, node_state.index)
|
126
|
+
self._node_states[key] = node_state.model_dump_json()
|
127
|
+
|
128
|
+
def get_next[R: Result](self, *, task: "Task[Any, R]") -> NodeState[R]:
|
129
|
+
if task.name not in self._last_requested_indices:
|
130
|
+
self._last_requested_indices[task.name] = 0
|
131
|
+
else:
|
132
|
+
self._last_requested_indices[task.name] += 1
|
133
|
+
|
134
|
+
return self.get(task=task, index=self._last_requested_indices[task.name])
|
135
|
+
|
136
|
+
def start(
|
137
|
+
self,
|
138
|
+
*,
|
139
|
+
flow_name: str,
|
140
|
+
run_id: str,
|
141
|
+
agent_logger: IAgentLogger | None = None,
|
142
|
+
) -> None:
|
143
|
+
self._last_requested_indices = {}
|
144
|
+
self._flow_name = flow_name
|
145
|
+
self._id = run_id
|
146
|
+
self._agent = Agent(
|
147
|
+
flow_name=self.flow_name, run_id=self.id, logger=agent_logger
|
148
|
+
)
|
149
|
+
|
150
|
+
def end(self) -> None:
|
151
|
+
self._last_requested_indices = {}
|
152
|
+
self._flow_name = ""
|
153
|
+
self._id = ""
|
154
|
+
self._agent = None
|
155
|
+
|
156
|
+
def clear_node(self, *, task: "Task[Any, Result]", index: int) -> None:
|
157
|
+
key = (task.name, index)
|
158
|
+
if key in self._node_states:
|
159
|
+
del self._node_states[key]
|
160
|
+
|
161
|
+
def dump(self) -> FlowRunState:
|
162
|
+
flow_args, flow_kwargs = self.flow_inputs
|
163
|
+
|
164
|
+
return FlowRunState(
|
165
|
+
node_states=self._node_states,
|
166
|
+
flow_args=flow_args,
|
167
|
+
flow_kwargs=flow_kwargs,
|
168
|
+
)
|
169
|
+
|
170
|
+
@classmethod
|
171
|
+
def load(cls, flow_run_state: FlowRunState, /) -> Self:
|
172
|
+
flow_run = cls()
|
173
|
+
flow_run._node_states = flow_run_state.node_states
|
174
|
+
flow_run._flow_args = flow_run_state.flow_args
|
175
|
+
flow_run._flow_kwargs = flow_run_state.flow_kwargs
|
176
|
+
|
177
|
+
return flow_run
|
178
|
+
|
179
|
+
|
180
|
+
# The flow run bound to the current execution context, or None when no run is
# active. Being a ContextVar, each context (e.g. each asyncio task) sees its
# own binding.
_current_flow_run: ContextVar[FlowRun | None] = ContextVar("current_flow_run", default=None)


def get_current_flow_run() -> FlowRun | None:
    """Return the flow run bound to the current context, or ``None``."""
    current = _current_flow_run.get()
    return current


def set_current_flow_run(flow_run: FlowRun | None) -> None:
    """Bind ``flow_run`` (or ``None`` to unbind) in the current context."""
    _current_flow_run.set(flow_run)
|
goose/_internal/store.py (renamed from goose/store.py)
CHANGED
@@ -1,9 +1,9 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
|
-
from typing import
|
3
|
+
from typing import Protocol
|
4
4
|
|
5
|
-
|
6
|
-
|
5
|
+
from goose._internal.flow import FlowRun
|
6
|
+
from goose._internal.state import FlowRunState
|
7
7
|
|
8
8
|
|
9
9
|
class IFlowRunStore(Protocol):
|
@@ -16,13 +16,15 @@ class IFlowRunStore(Protocol):
|
|
16
16
|
class InMemoryFlowRunStore(IFlowRunStore):
    """``IFlowRunStore`` that keeps ``FlowRunState`` snapshots in a dict.

    Snapshots live only as long as this instance.
    """

    def __init__(self, *, flow_name: str) -> None:
        self._flow_name = flow_name
        # run_id -> last saved snapshot
        self._runs: dict[str, FlowRunState] = {}

    async def get(self, *, run_id: str) -> FlowRun | None:
        """Rebuild the run saved under ``run_id``; ``None`` if there is none."""
        if (state := self._runs.get(run_id)) is None:
            return None
        return FlowRun.load(state)

    async def save(self, *, run: FlowRun) -> None:
        """Store a snapshot of ``run`` under ``run.id``, replacing any previous one."""
        self._runs[run.id] = run.dump()

    async def delete(self, *, run_id: str) -> None:
        """Forget ``run_id``; unknown ids are ignored."""
        self._runs.pop(run_id, None)
|