goose-py 0.5.1__tar.gz → 0.7.0__tar.gz

This diff shows the content of publicly available package versions as released to their respective public registries. It is provided for informational purposes only.
Files changed (37)
  1. goose_py-0.7.0/.envrc +1 -0
  2. goose_py-0.7.0/.github/workflows/publish.yml +75 -0
  3. goose_py-0.7.0/.gitignore +5 -0
  4. goose_py-0.7.0/.python-version +1 -0
  5. goose_py-0.7.0/.stubs/jsonpath_ng/__init__.pyi +14 -0
  6. goose_py-0.7.0/.stubs/litellm/__init__.pyi +61 -0
  7. goose_py-0.7.0/PKG-INFO +14 -0
  8. goose_py-0.7.0/goose/__init__.py +4 -0
  9. {goose_py-0.5.1/goose → goose_py-0.7.0/goose/_internal}/agent.py +8 -90
  10. goose_py-0.7.0/goose/_internal/conversation.py +41 -0
  11. goose_py-0.7.0/goose/_internal/flow.py +106 -0
  12. goose_py-0.7.0/goose/_internal/state.py +190 -0
  13. {goose_py-0.5.1/goose → goose_py-0.7.0/goose/_internal}/store.py +8 -6
  14. goose_py-0.7.0/goose/_internal/task.py +136 -0
  15. goose_py-0.7.0/goose/_internal/types/agent.py +92 -0
  16. goose_py-0.7.0/goose/agent.py +28 -0
  17. goose_py-0.7.0/goose/flow.py +3 -0
  18. goose_py-0.7.0/goose/runs.py +4 -0
  19. goose_py-0.7.0/pyproject.toml +80 -0
  20. goose_py-0.7.0/tests/__init__.py +0 -0
  21. goose_py-0.7.0/tests/conftest.py +9 -0
  22. goose_py-0.7.0/tests/test_agent.py +72 -0
  23. goose_py-0.7.0/tests/test_complex_flow_arguments.py +23 -0
  24. goose_py-0.7.0/tests/test_downstream_task.py +37 -0
  25. goose_py-0.7.0/tests/test_jamming.py +78 -0
  26. goose_py-0.7.0/tests/test_looping.py +67 -0
  27. goose_py-0.7.0/tests/test_regenerate.py +48 -0
  28. goose_py-0.7.0/tests/test_state.py +51 -0
  29. goose_py-0.7.0/uv.lock +1057 -0
  30. goose_py-0.5.1/PKG-INFO +0 -31
  31. goose_py-0.5.1/goose/flow.py +0 -458
  32. goose_py-0.5.1/pyproject.toml +0 -62
  33. {goose_py-0.5.1 → goose_py-0.7.0}/README.md +0 -0
  34. {goose_py-0.5.1/goose → goose_py-0.7.0/goose/_internal}/result.py +0 -0
  35. {goose_py-0.5.1/goose → goose_py-0.7.0/goose/_internal/types}/__init__.py +0 -0
  36. {goose_py-0.5.1 → goose_py-0.7.0}/goose/errors.py +0 -0
  37. {goose_py-0.5.1 → goose_py-0.7.0}/goose/py.typed +0 -0
goose_py-0.7.0/.envrc ADDED
@@ -0,0 +1 @@
+ dotenv
goose_py-0.7.0/.github/workflows/publish.yml ADDED
@@ -0,0 +1,75 @@
+ name: CI/CD
+
+ on:
+   push:
+     branches:
+       - main
+     tags:
+       - "v*"
+   pull_request:
+     branches:
+       - main
+
+ jobs:
+   test:
+     runs-on: ubuntu-latest
+
+     steps:
+       - uses: actions/checkout@v4
+         with:
+           fetch-depth: 0
+
+       - name: Set up Python
+         uses: actions/setup-python@v4
+         with:
+           python-version: "3.12"
+
+       - name: Install Poetry
+         uses: snok/install-poetry@v1
+         with:
+           version: 1.8.5
+           virtualenvs-create: true
+           virtualenvs-in-project: true
+
+       - name: Install dependencies
+         run: poetry install --all-extras
+
+       - name: Run tests
+         run: poetry run pytest tests
+         env:
+           GEMINI_API_KEY: ${{ secrets.GEMINI_API_KEY }}
+
+   publish:
+     needs: test
+     runs-on: ubuntu-latest
+     if: github.event_name == 'push' && (startsWith(github.ref, 'refs/tags/') || github.ref == 'refs/heads/main')
+     permissions:
+       id-token: write
+       contents: read
+
+     steps:
+       - uses: actions/checkout@v4
+         with:
+           fetch-depth: 0
+
+       - name: Set up Python
+         uses: actions/setup-python@v4
+         with:
+           python-version: "3.12"
+
+       - name: Install Poetry
+         uses: snok/install-poetry@v1
+         with:
+           version: 1.8.5
+           virtualenvs-create: true
+           virtualenvs-in-project: true
+
+       - name: Install dependencies
+         run: poetry install --all-extras
+
+       - name: Build package
+         run: poetry build
+
+       - name: Publish to PyPI
+         if: startsWith(github.ref, 'refs/tags/')
+         uses: pypa/gh-action-pypi-publish@release/v1
goose_py-0.7.0/.gitignore ADDED
@@ -0,0 +1,5 @@
+ __pycache__
+ .env
+ poetry.lock
+ notebooks
+ dist
goose_py-0.7.0/.python-version ADDED
@@ -0,0 +1 @@
+ 3.12
goose_py-0.7.0/.stubs/jsonpath_ng/__init__.pyi ADDED
@@ -0,0 +1,14 @@
+ from typing import Any
+
+ class Fields:
+     fields: tuple[str, ...]
+
+ class DatumInContext:
+     value: Any
+     fields: Fields
+     context: "DatumInContext"
+
+ class Jsonpath:
+     def find(self, /, value: Any) -> list[DatumInContext]: ...
+
+ class parse: ...
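
This stub types only the small slice of jsonpath-ng that goose uses. A minimal sketch of that surface, with an illustrative document and path (not taken from the package):

from typing import Any

from jsonpath_ng import parse

document: dict[str, Any] = {"tasks": [{"name": "plan"}, {"name": "draft"}]}

# parse() compiles the expression; find() returns DatumInContext matches
# whose .value holds the matched node.
matches = parse("tasks[*].name").find(document)
print([match.value for match in matches])  # ['plan', 'draft']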
goose_py-0.7.0/.stubs/litellm/__init__.pyi ADDED
@@ -0,0 +1,61 @@
+ from typing import Any, Literal, NotRequired, TypedDict
+
+ _LiteLLMGeminiModel = Literal[
+     "vertex_ai/gemini-1.5-flash",
+     "vertex_ai/gemini-1.5-pro",
+     "vertex_ai/gemini-1.5-flash-8b",
+ ]
+ _MessageRole = Literal["system", "user", "assistant"]
+
+ class _LiteLLMTextMessageContent(TypedDict):
+     type: Literal["text"]
+     text: str
+
+ class _LiteLLMMediaMessageContent(TypedDict):
+     type: Literal["image_url"]
+     image_url: str
+
+ class _LiteLLMCacheControl(TypedDict):
+     type: Literal["ephemeral"]
+
+ class _LiteLLMMessage(TypedDict):
+     role: _MessageRole
+     content: list[_LiteLLMTextMessageContent | _LiteLLMMediaMessageContent]
+     cache_control: NotRequired[_LiteLLMCacheControl]
+
+ class _LiteLLMResponseFormat(TypedDict):
+     type: Literal["json_object"]
+     response_schema: dict[str, Any]  # must be a valid JSON schema
+     enforce_validation: NotRequired[bool]
+
+ class _LiteLLMModelResponseChoiceMessage:
+     role: Literal["assistant"]
+     content: str
+
+ class _LiteLLMModelResponseChoice:
+     finish_reason: Literal["stop"]
+     index: int
+     message: _LiteLLMModelResponseChoiceMessage
+
+ class _LiteLLMUsage:
+     completion_tokens: int
+     prompt_tokens: int
+     total_tokens: int
+
+ class ModelResponse:
+     id: str
+     created: int
+     model: _LiteLLMGeminiModel
+     object: Literal["chat.completion"]
+     system_fingerprint: str | None
+     choices: list[_LiteLLMModelResponseChoice]
+     usage: _LiteLLMUsage
+
+ async def acompletion(
+     *,
+     model: _LiteLLMGeminiModel,
+     messages: list[_LiteLLMMessage],
+     response_format: _LiteLLMResponseFormat | None = None,
+     max_tokens: int | None = None,
+     temperature: float = 1.0,
+ ) -> ModelResponse: ...
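
The stub pins acompletion to keyword arguments and Gemini model names. A minimal sketch of a call that satisfies these types, assuming credentials for the vertex_ai provider are configured; the prompt is illustrative:

import asyncio

from litellm import acompletion


async def main() -> None:
    # The message shape mirrors _LiteLLMMessage above: a role plus a list
    # of typed content parts.
    response = await acompletion(
        model="vertex_ai/gemini-1.5-flash",
        messages=[
            {"role": "user", "content": [{"type": "text", "text": "Say hello."}]}
        ],
    )
    # ModelResponse exposes choices[i].message.content per the stub.
    print(response.choices[0].message.content)


asyncio.run(main())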
goose_py-0.7.0/PKG-INFO ADDED
@@ -0,0 +1,14 @@
+ Metadata-Version: 2.4
+ Name: goose-py
+ Version: 0.7.0
+ Summary: A tool for AI workflows based on human-computer collaboration and structured output.
+ Author-email: Nash Taylor <nash@chelle.ai>, Joshua Cook <joshua@chelle.ai>, Michael Sankur <michael@chelle.ai>
+ Requires-Python: >=3.12
+ Requires-Dist: jsonpath-ng>=1.7.0
+ Requires-Dist: litellm>=1.56.5
+ Requires-Dist: pydantic>=2.8.2
+ Description-Content-Type: text/markdown
+
+ # Goose
+
+ Docs to come.
goose_py-0.7.0/goose/__init__.py ADDED
@@ -0,0 +1,4 @@
+ from goose._internal.agent import Agent
+ from goose._internal.flow import flow
+ from goose._internal.result import Result, TextResult
+ from goose._internal.task import task
{goose_py-0.5.1/goose → goose_py-0.7.0/goose/_internal}/agent.py RENAMED
@@ -1,100 +1,18 @@
  import json
  import logging
  from datetime import datetime
- from enum import StrEnum
- from typing import Any, ClassVar, Literal, NotRequired, Protocol, TypedDict
+ from typing import Any, ClassVar, Protocol, TypedDict

  from litellm import acompletion
  from pydantic import BaseModel, computed_field
- from goose.result import Result, TextResult

-
- class GeminiModel(StrEnum):
-     PRO = "vertex_ai/gemini-1.5-pro"
-     FLASH = "vertex_ai/gemini-1.5-flash"
-     FLASH_8B = "vertex_ai/gemini-1.5-flash-8b"
-
-
- class UserMediaContentType(StrEnum):
-     # images
-     JPEG = "image/jpeg"
-     PNG = "image/png"
-     WEBP = "image/webp"
-
-     # audio
-     MP3 = "audio/mp3"
-     WAV = "audio/wav"
-
-     # files
-     PDF = "application/pdf"
-
-
- class LLMTextMessagePart(TypedDict):
-     type: Literal["text"]
-     text: str
-
-
- class LLMMediaMessagePart(TypedDict):
-     type: Literal["image_url"]
-     image_url: str
-
-
- class CacheControl(TypedDict):
-     type: Literal["ephemeral"]
-
-
- class LLMMessage(TypedDict):
-     role: Literal["user", "assistant", "system"]
-     content: list[LLMTextMessagePart | LLMMediaMessagePart]
-     cache_control: NotRequired[CacheControl]
-
-
- class TextMessagePart(BaseModel):
-     text: str
-
-     def render(self) -> LLMTextMessagePart:
-         return {"type": "text", "text": self.text}
-
-
- class MediaMessagePart(BaseModel):
-     content_type: UserMediaContentType
-     content: str
-
-     def render(self) -> LLMMediaMessagePart:
-         return {
-             "type": "image_url",
-             "image_url": f"data:{self.content_type};base64,{self.content}",
-         }
-
-
- class UserMessage(BaseModel):
-     parts: list[TextMessagePart | MediaMessagePart]
-
-     def render(self) -> LLMMessage:
-         content: LLMMessage = {
-             "role": "user",
-             "content": [part.render() for part in self.parts],
-         }
-         if any(isinstance(part, MediaMessagePart) for part in self.parts):
-             content["cache_control"] = {"type": "ephemeral"}
-         return content
-
-
- class AssistantMessage(BaseModel):
-     text: str
-
-     def render(self) -> LLMMessage:
-         return {"role": "assistant", "content": [{"type": "text", "text": self.text}]}
-
-
- class SystemMessage(BaseModel):
-     parts: list[TextMessagePart | MediaMessagePart]
-
-     def render(self) -> LLMMessage:
-         return {
-             "role": "system",
-             "content": [part.render() for part in self.parts],
-         }
+ from goose._internal.result import Result, TextResult
+ from goose._internal.types.agent import (
+     AssistantMessage,
+     GeminiModel,
+     SystemMessage,
+     UserMessage,
+ )


  class AgentResponseDump(TypedDict):
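
The message-part models deleted here are not gone; per the file list they move to goose/_internal/types/agent.py, and agent.py now imports them from there. A small sketch of the rendering behavior shown in the removed code, assuming TextMessagePart moves alongside the message classes:

from goose._internal.types.agent import TextMessagePart, UserMessage

# A text-only user message renders to the LiteLLM message shape; per the
# render() logic above, cache_control is only attached when a media part
# is present.
message = UserMessage(parts=[TextMessagePart(text="Summarize this document.")])
print(message.render())
# {'role': 'user', 'content': [{'type': 'text', 'text': 'Summarize this document.'}]}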
goose_py-0.7.0/goose/_internal/conversation.py ADDED
@@ -0,0 +1,41 @@
+ from pydantic import BaseModel
+
+ from goose._internal.result import Result
+ from goose._internal.types.agent import (
+     AssistantMessage,
+     LLMMessage,
+     SystemMessage,
+     UserMessage,
+ )
+
+
+ class Conversation[R: Result](BaseModel):
+     user_messages: list[UserMessage]
+     result_messages: list[R]
+     context: SystemMessage | None = None
+
+     @property
+     def awaiting_response(self) -> bool:
+         return len(self.user_messages) == len(self.result_messages)
+
+     def render(self) -> list[LLMMessage]:
+         messages: list[LLMMessage] = []
+         if self.context is not None:
+             messages.append(self.context.render())
+
+         for message_index in range(len(self.user_messages)):
+             messages.append(
+                 AssistantMessage(
+                     text=self.result_messages[message_index].model_dump_json()
+                 ).render()
+             )
+             messages.append(self.user_messages[message_index].render())
+
+         if len(self.result_messages) > len(self.user_messages):
+             messages.append(
+                 AssistantMessage(
+                     text=self.result_messages[-1].model_dump_json()
+                 ).render()
+             )
+
+         return messages
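
Conversation stores user messages and structured results in parallel lists and interleaves them result-first when rendering. A sketch of that behavior, using a hypothetical Result subclass (Result itself moves unchanged in this release, so its fields are not shown in the diff):

from goose._internal.conversation import Conversation
from goose._internal.result import Result
from goose._internal.types.agent import TextMessagePart, UserMessage


class Summary(Result):  # hypothetical result type for illustration
    text: str


conversation = Conversation[Summary](
    user_messages=[UserMessage(parts=[TextMessagePart(text="Make it shorter.")])],
    result_messages=[Summary(text="A long first draft of the summary.")],
)

# Equal list lengths mean the latest user message has no result yet.
print(conversation.awaiting_response)  # True
# render() emits assistant/user pairs, so the model sees its prior result
# followed by the user's follow-up request.
print([m["role"] for m in conversation.render()])  # ['assistant', 'user']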
goose_py-0.7.0/goose/_internal/flow.py ADDED
@@ -0,0 +1,106 @@
+ from contextlib import asynccontextmanager
+ from types import CodeType
+ from typing import AsyncIterator, Awaitable, Callable, Protocol, overload
+
+ from goose._internal.agent import Agent, IAgentLogger
+ from goose._internal.conversation import Conversation
+ from goose._internal.result import Result
+ from goose._internal.state import FlowRun, get_current_flow_run, set_current_flow_run
+ from goose._internal.store import IFlowRunStore, InMemoryFlowRunStore
+ from goose.errors import Honk
+
+
+ class IAdapter[ResultT: Result](Protocol):
+     __code__: CodeType
+
+     async def __call__(
+         self, *, conversation: Conversation[ResultT], agent: Agent
+     ) -> ResultT: ...
+
+
+ class Flow[**P]:
+     def __init__(
+         self,
+         fn: Callable[P, Awaitable[None]],
+         /,
+         *,
+         name: str | None = None,
+         store: IFlowRunStore | None = None,
+         agent_logger: IAgentLogger | None = None,
+     ) -> None:
+         self._fn = fn
+         self._name = name
+         self._agent_logger = agent_logger
+         self._store = store or InMemoryFlowRunStore(flow_name=self.name)
+
+     @property
+     def name(self) -> str:
+         return self._name or self._fn.__name__
+
+     @property
+     def current_run(self) -> FlowRun:
+         run = get_current_flow_run()
+         if run is None:
+             raise Honk("No current flow run")
+         return run
+
+     @asynccontextmanager
+     async def start_run(self, *, run_id: str) -> AsyncIterator[FlowRun]:
+         existing_run = await self._store.get(run_id=run_id)
+         if existing_run is None:
+             run = FlowRun()
+         else:
+             run = existing_run
+
+         old_run = get_current_flow_run()
+         set_current_flow_run(run)
+
+         run.start(flow_name=self.name, run_id=run_id, agent_logger=self._agent_logger)
+         yield run
+         await self._store.save(run=run)
+         run.end()
+
+         set_current_flow_run(old_run)
+
+     async def generate(self, *args: P.args, **kwargs: P.kwargs) -> None:
+         flow_run = get_current_flow_run()
+         if flow_run is None:
+             raise Honk("No current flow run")
+
+         flow_run.set_flow_inputs(*args, **kwargs)
+         await self._fn(*args, **kwargs)
+
+     async def regenerate(self) -> None:
+         flow_run = get_current_flow_run()
+         if flow_run is None:
+             raise Honk("No current flow run")
+
+         flow_args, flow_kwargs = flow_run.flow_inputs
+         await self._fn(*flow_args, **flow_kwargs)
+
+
+ @overload
+ def flow[**P](fn: Callable[P, Awaitable[None]], /) -> Flow[P]: ...
+ @overload
+ def flow[**P](
+     *,
+     name: str | None = None,
+     store: IFlowRunStore | None = None,
+     agent_logger: IAgentLogger | None = None,
+ ) -> Callable[[Callable[P, Awaitable[None]]], Flow[P]]: ...
+ def flow[**P](
+     fn: Callable[P, Awaitable[None]] | None = None,
+     /,
+     *,
+     name: str | None = None,
+     store: IFlowRunStore | None = None,
+     agent_logger: IAgentLogger | None = None,
+ ) -> Flow[P] | Callable[[Callable[P, Awaitable[None]]], Flow[P]]:
+     if fn is None:
+
+         def decorator(fn: Callable[P, Awaitable[None]]) -> Flow[P]:
+             return Flow(fn, name=name, store=store, agent_logger=agent_logger)
+
+         return decorator
+
+     return Flow(fn, name=name, store=store, agent_logger=agent_logger)
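
A minimal sketch of the decorator and run lifecycle defined above. The flow body, run id, and argument are illustrative; a real flow body would invoke tasks and the run's agent:

import asyncio

from goose import flow


@flow
async def greeting_flow(*, name: str) -> None:
    # Placeholder body; task calls would go here.
    print(f"Running flow for {name}")


async def main() -> None:
    # start_run() installs the FlowRun in a ContextVar so generate() can
    # record the flow inputs, then persists the run to the store on exit.
    async with greeting_flow.start_run(run_id="run-1") as run:
        await greeting_flow.generate(name="ada")
    print(run.flow_inputs)  # ((), {'name': 'ada'})


asyncio.run(main())

Because start_run() restores the previous ContextVar value on exit, runs can be nested or interleaved across asyncio tasks without clobbering each other.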
goose_py-0.7.0/goose/_internal/state.py ADDED
@@ -0,0 +1,190 @@
+ from contextvars import ContextVar
+ from dataclasses import dataclass
+ from typing import TYPE_CHECKING, Any, Self
+
+ from pydantic import BaseModel
+
+ from goose._internal.agent import (
+     Agent,
+     IAgentLogger,
+     SystemMessage,
+     UserMessage,
+ )
+ from goose._internal.conversation import Conversation
+ from goose._internal.result import Result
+ from goose.errors import Honk
+
+ if TYPE_CHECKING:
+     from goose._internal.task import Task
+
+
+ @dataclass
+ class FlowRunState:
+     node_states: dict[tuple[str, int], str]
+     flow_args: tuple[Any, ...]
+     flow_kwargs: dict[str, Any]
+
+
+ class NodeState[ResultT: Result](BaseModel):
+     task_name: str
+     index: int
+     conversation: Conversation[ResultT]
+     last_hash: int
+
+     @property
+     def result(self) -> ResultT:
+         if len(self.conversation.result_messages) == 0:
+             raise Honk("Node awaiting response, has no result")
+
+         return self.conversation.result_messages[-1]
+
+     def set_context(self, *, context: SystemMessage) -> Self:
+         self.conversation.context = context
+         return self
+
+     def add_result(
+         self,
+         *,
+         result: ResultT,
+         new_hash: int | None = None,
+         overwrite: bool = False,
+     ) -> Self:
+         if overwrite and len(self.conversation.result_messages) > 0:
+             self.conversation.result_messages[-1] = result
+         else:
+             self.conversation.result_messages.append(result)
+         if new_hash is not None:
+             self.last_hash = new_hash
+         return self
+
+     def add_user_message(self, *, message: UserMessage) -> Self:
+         self.conversation.user_messages.append(message)
+         return self
+
+
+ class FlowRun:
+     def __init__(self) -> None:
+         self._node_states: dict[tuple[str, int], str] = {}
+         self._last_requested_indices: dict[str, int] = {}
+         self._flow_name = ""
+         self._id = ""
+         self._agent: Agent | None = None
+         self._flow_args: tuple[Any, ...] | None = None
+         self._flow_kwargs: dict[str, Any] | None = None
+
+     @property
+     def flow_name(self) -> str:
+         return self._flow_name
+
+     @property
+     def id(self) -> str:
+         return self._id
+
+     @property
+     def agent(self) -> Agent:
+         if self._agent is None:
+             raise Honk("Agent is only accessible once a run is started")
+         return self._agent
+
+     @property
+     def flow_inputs(self) -> tuple[tuple[Any, ...], dict[str, Any]]:
+         if self._flow_args is None or self._flow_kwargs is None:
+             raise Honk("This Flow run has not been executed before")
+
+         return self._flow_args, self._flow_kwargs
+
+     def get_all[R: Result](self, *, task: "Task[Any, R]") -> list[NodeState[R]]:
+         matching_nodes: list[NodeState[R]] = []
+         for key, node_state in self._node_states.items():
+             if key[0] == task.name:
+                 matching_nodes.append(
+                     NodeState[task.result_type].model_validate_json(node_state)
+                 )
+         return sorted(matching_nodes, key=lambda node: node.index)
+
+     def get[R: Result](self, *, task: "Task[Any, R]", index: int = 0) -> NodeState[R]:
+         if (
+             existing_node_state := self._node_states.get((task.name, index))
+         ) is not None:
+             return NodeState[task.result_type].model_validate_json(existing_node_state)
+         else:
+             return NodeState[task.result_type](
+                 task_name=task.name,
+                 index=index,
+                 conversation=Conversation[task.result_type](
+                     user_messages=[], result_messages=[]
+                 ),
+                 last_hash=0,
+             )
+
+     def set_flow_inputs(self, *args: Any, **kwargs: Any) -> None:
+         self._flow_args = args
+         self._flow_kwargs = kwargs
+
+     def add_node_state(self, node_state: NodeState[Any], /) -> None:
+         key = (node_state.task_name, node_state.index)
+         self._node_states[key] = node_state.model_dump_json()
+
+     def get_next[R: Result](self, *, task: "Task[Any, R]") -> NodeState[R]:
+         if task.name not in self._last_requested_indices:
+             self._last_requested_indices[task.name] = 0
+         else:
+             self._last_requested_indices[task.name] += 1
+
+         return self.get(task=task, index=self._last_requested_indices[task.name])
+
+     def start(
+         self,
+         *,
+         flow_name: str,
+         run_id: str,
+         agent_logger: IAgentLogger | None = None,
+     ) -> None:
+         self._last_requested_indices = {}
+         self._flow_name = flow_name
+         self._id = run_id
+         self._agent = Agent(
+             flow_name=self.flow_name, run_id=self.id, logger=agent_logger
+         )
+
+     def end(self) -> None:
+         self._last_requested_indices = {}
+         self._flow_name = ""
+         self._id = ""
+         self._agent = None
+
+     def clear_node(self, *, task: "Task[Any, Result]", index: int) -> None:
+         key = (task.name, index)
+         if key in self._node_states:
+             del self._node_states[key]
+
+     def dump(self) -> FlowRunState:
+         flow_args, flow_kwargs = self.flow_inputs
+
+         return FlowRunState(
+             node_states=self._node_states,
+             flow_args=flow_args,
+             flow_kwargs=flow_kwargs,
+         )
+
+     @classmethod
+     def load(cls, flow_run_state: FlowRunState, /) -> Self:
+         flow_run = cls()
+         flow_run._node_states = flow_run_state.node_states
+         flow_run._flow_args = flow_run_state.flow_args
+         flow_run._flow_kwargs = flow_run_state.flow_kwargs
+
+         return flow_run
+
+
+ _current_flow_run: ContextVar[FlowRun | None] = ContextVar(
+     "current_flow_run", default=None
+ )
+
+
+ def get_current_flow_run() -> FlowRun | None:
+     return _current_flow_run.get()
+
+
+ def set_current_flow_run(flow_run: FlowRun | None) -> None:
+     _current_flow_run.set(flow_run)
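
NodeState serializes each task invocation's conversation as JSON for storage in FlowRun's node-state map. A sketch of the accessor behavior, again with a hypothetical Result subclass:

from goose._internal.conversation import Conversation
from goose._internal.result import Result
from goose._internal.state import NodeState


class Summary(Result):  # hypothetical result type for illustration
    text: str


node = NodeState[Summary](
    task_name="summarize",
    index=0,
    conversation=Conversation[Summary](user_messages=[], result_messages=[]),
    last_hash=0,
)

# Before any result is added, node.result raises Honk. add_result()
# appends by default and records the hash used for cache invalidation.
node.add_result(result=Summary(text="First draft."), new_hash=1)
print(node.result.text)  # First draft.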
{goose_py-0.5.1/goose → goose_py-0.7.0/goose/_internal}/store.py RENAMED
@@ -1,9 +1,9 @@
  from __future__ import annotations

- from typing import TYPE_CHECKING, Protocol
+ from typing import Protocol

- if TYPE_CHECKING:
-     from goose.flow import FlowRun
+ from goose._internal.flow import FlowRun
+ from goose._internal.state import FlowRunState


  class IFlowRunStore(Protocol):
@@ -16,13 +16,15 @@ class IFlowRunStore(Protocol):
  class InMemoryFlowRunStore(IFlowRunStore):
      def __init__(self, *, flow_name: str) -> None:
          self._flow_name = flow_name
-         self._runs: dict[str, FlowRun] = {}
+         self._runs: dict[str, FlowRunState] = {}

      async def get(self, *, run_id: str) -> FlowRun | None:
-         return self._runs.get(run_id)
+         state = self._runs.get(run_id)
+         if state is not None:
+             return FlowRun.load(state)

      async def save(self, *, run: FlowRun) -> None:
-         self._runs[run.id] = run
+         self._runs[run.id] = run.dump()

      async def delete(self, *, run_id: str) -> None:
          self._runs.pop(run_id, None)
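
Stores now persist FlowRunState snapshots via dump()/load() rather than holding live FlowRun objects. A hypothetical disk-backed implementation of the same protocol, sketched with pickle because flow_args and flow_kwargs may hold arbitrary Python values:

import pickle
from pathlib import Path

from goose._internal.flow import FlowRun
from goose._internal.store import IFlowRunStore


class PickleFlowRunStore(IFlowRunStore):
    """Hypothetical store mirroring InMemoryFlowRunStore, but persisting
    each run's dumped FlowRunState to disk."""

    def __init__(self, *, flow_name: str, directory: Path) -> None:
        self._directory = directory / flow_name
        self._directory.mkdir(parents=True, exist_ok=True)

    async def get(self, *, run_id: str) -> FlowRun | None:
        path = self._directory / f"{run_id}.pkl"
        if not path.exists():
            return None
        # Rehydrate a FlowRun from the serialized FlowRunState snapshot.
        return FlowRun.load(pickle.loads(path.read_bytes()))

    async def save(self, *, run: FlowRun) -> None:
        (self._directory / f"{run.id}.pkl").write_bytes(pickle.dumps(run.dump()))

    async def delete(self, *, run_id: str) -> None:
        (self._directory / f"{run_id}.pkl").unlink(missing_ok=True)

Snapshotting via dump() also makes the in-memory store safer than in 0.5.1: saved runs are now decoupled from the live object, so later mutations no longer leak into the stored copy.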