goose-py 0.5.1__tar.gz → 0.7.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- goose_py-0.7.0/.envrc +1 -0
- goose_py-0.7.0/.github/workflows/publish.yml +75 -0
- goose_py-0.7.0/.gitignore +5 -0
- goose_py-0.7.0/.python-version +1 -0
- goose_py-0.7.0/.stubs/jsonpath_ng/__init__.pyi +14 -0
- goose_py-0.7.0/.stubs/litellm/__init__.pyi +61 -0
- goose_py-0.7.0/PKG-INFO +14 -0
- goose_py-0.7.0/goose/__init__.py +4 -0
- {goose_py-0.5.1/goose → goose_py-0.7.0/goose/_internal}/agent.py +8 -90
- goose_py-0.7.0/goose/_internal/conversation.py +41 -0
- goose_py-0.7.0/goose/_internal/flow.py +106 -0
- goose_py-0.7.0/goose/_internal/state.py +190 -0
- {goose_py-0.5.1/goose → goose_py-0.7.0/goose/_internal}/store.py +8 -6
- goose_py-0.7.0/goose/_internal/task.py +136 -0
- goose_py-0.7.0/goose/_internal/types/agent.py +92 -0
- goose_py-0.7.0/goose/agent.py +28 -0
- goose_py-0.7.0/goose/flow.py +3 -0
- goose_py-0.7.0/goose/runs.py +4 -0
- goose_py-0.7.0/pyproject.toml +80 -0
- goose_py-0.7.0/tests/__init__.py +0 -0
- goose_py-0.7.0/tests/conftest.py +9 -0
- goose_py-0.7.0/tests/test_agent.py +72 -0
- goose_py-0.7.0/tests/test_complex_flow_arguments.py +23 -0
- goose_py-0.7.0/tests/test_downstream_task.py +37 -0
- goose_py-0.7.0/tests/test_jamming.py +78 -0
- goose_py-0.7.0/tests/test_looping.py +67 -0
- goose_py-0.7.0/tests/test_regenerate.py +48 -0
- goose_py-0.7.0/tests/test_state.py +51 -0
- goose_py-0.7.0/uv.lock +1057 -0
- goose_py-0.5.1/PKG-INFO +0 -31
- goose_py-0.5.1/goose/flow.py +0 -458
- goose_py-0.5.1/pyproject.toml +0 -62
- {goose_py-0.5.1 → goose_py-0.7.0}/README.md +0 -0
- {goose_py-0.5.1/goose → goose_py-0.7.0/goose/_internal}/result.py +0 -0
- {goose_py-0.5.1/goose → goose_py-0.7.0/goose/_internal/types}/__init__.py +0 -0
- {goose_py-0.5.1 → goose_py-0.7.0}/goose/errors.py +0 -0
- {goose_py-0.5.1 → goose_py-0.7.0}/goose/py.typed +0 -0
goose_py-0.7.0/.envrc
ADDED
@@ -0,0 +1 @@
|
|
1
|
+
dotenv
|
@@ -0,0 +1,75 @@
|
|
1
|
+
# CI/CD pipeline: run the test suite on pushes and pull requests targeting
# main and on version tags, then build the package (and publish it, for tag
# pushes only).
name: CI/CD

on:
  push:
    branches:
      - main
    tags:
      - "v*" # version tags, e.g. v0.7.0
  pull_request:
    branches:
      - main

jobs:
  test:
    runs-on: ubuntu-latest

    steps:
      - uses: actions/checkout@v4
        with:
          fetch-depth: 0 # full history instead of the default shallow clone

      - name: Set up Python
        uses: actions/setup-python@v4
        with:
          python-version: "3.12" # same version as .python-version / Requires-Python >=3.12

      # NOTE(review): the repository also ships a uv.lock; confirm that
      # pyproject.toml is Poetry-compatible, otherwise `poetry install` fails.
      - name: Install Poetry
        uses: snok/install-poetry@v1
        with:
          version: 1.8.5
          virtualenvs-create: true
          virtualenvs-in-project: true

      - name: Install dependencies
        run: poetry install --all-extras

      - name: Run tests
        run: poetry run pytest tests
        env:
          # Presumably some tests call the Gemini API; the key comes from a
          # repository secret.
          GEMINI_API_KEY: ${{ secrets.GEMINI_API_KEY }}

  # Runs only for push events (never for pull requests), and only after the
  # test job succeeds. Pushes to main still build the package, but the
  # publish step below is skipped unless the ref is a tag.
  publish:
    needs: test
    runs-on: ubuntu-latest
    if: github.event_name == 'push' && (startsWith(github.ref, 'refs/tags/') || github.ref == 'refs/heads/main')
    permissions:
      # OIDC token, presumably for PyPI trusted publishing (no API token is
      # passed to the publish step).
      id-token: write
      contents: read

    steps:
      - uses: actions/checkout@v4
        with:
          fetch-depth: 0

      - name: Set up Python
        uses: actions/setup-python@v4
        with:
          python-version: "3.12"

      - name: Install Poetry
        uses: snok/install-poetry@v1
        with:
          version: 1.8.5
          virtualenvs-create: true
          virtualenvs-in-project: true

      - name: Install dependencies
        run: poetry install --all-extras

      - name: Build package
        run: poetry build

      # Upload only when the push was a tag.
      - name: Publish to PyPI
        if: startsWith(github.ref, 'refs/tags/')
        uses: pypa/gh-action-pypi-publish@release/v1
|
@@ -0,0 +1 @@
|
|
1
|
+
3.12
|
@@ -0,0 +1,14 @@
|
|
1
|
+
from typing import Any
|
2
|
+
|
3
|
+
# Minimal hand-written type stub for the third-party `jsonpath_ng` package.
# Only the names declared below are typed here.


class Fields:
    # Field names held by this path node.
    # NOTE(review): mirrors jsonpath_ng's `Fields` node; confirm semantics.
    fields: tuple[str, ...]


class DatumInContext:
    # One match produced by `Jsonpath.find`.
    value: Any  # the matched value (per jsonpath_ng's DatumInContext)
    fields: Fields  # presumably the path segment that produced this match — verify
    context: "DatumInContext"  # self-referential: presumably the enclosing (parent) match


class Jsonpath:
    # A compiled JSONPath expression, as returned by `parse`.
    def find(self, /, value: Any) -> list[DatumInContext]: ...


# Compile a JSONPath expression string into a `Jsonpath` object.
def parse(jsonpath: str, /) -> Jsonpath: ...
|
@@ -0,0 +1,61 @@
|
|
1
|
+
from typing import Any, Literal, NotRequired, TypedDict
|
2
|
+
|
3
|
+
# Minimal hand-written type stub for `litellm`: declares `acompletion` and
# the request/response shapes it uses.

# Model identifiers accepted by the stubbed `acompletion`: three Vertex AI
# Gemini 1.5 variants. Any other model string is a type error.
_LiteLLMGeminiModel = Literal[
    "vertex_ai/gemini-1.5-flash",
    "vertex_ai/gemini-1.5-pro",
    "vertex_ai/gemini-1.5-flash-8b",
]
_MessageRole = Literal["system", "user", "assistant"]


class _LiteLLMTextMessageContent(TypedDict):
    # A plain-text content part of a message.
    type: Literal["text"]
    text: str


class _LiteLLMMediaMessageContent(TypedDict):
    # A media content part; `image_url` is typed as a plain string here.
    type: Literal["image_url"]
    image_url: str


class _LiteLLMCacheControl(TypedDict):
    type: Literal["ephemeral"]


class _LiteLLMMessage(TypedDict):
    # One chat message: a role plus a list of text/media content parts.
    role: _MessageRole
    content: list[_LiteLLMTextMessageContent | _LiteLLMMediaMessageContent]
    cache_control: NotRequired[_LiteLLMCacheControl]  # optional key


class _LiteLLMResponseFormat(TypedDict):
    # Structured-output request; presumably asks for JSON conforming to
    # `response_schema`.
    type: Literal["json_object"]
    response_schema: dict[str, Any]  # must be a valid JSON schema
    enforce_validation: NotRequired[bool]


class _LiteLLMModelResponseChoiceMessage:
    role: Literal["assistant"]
    content: str


class _LiteLLMModelResponseChoice:
    # NOTE(review): typed as "stop" only; the real API can report other
    # finish reasons — confirm this narrowing is intended.
    finish_reason: Literal["stop"]
    index: int
    message: _LiteLLMModelResponseChoiceMessage


class _LiteLLMUsage:
    # Token counts for one completion.
    completion_tokens: int
    prompt_tokens: int
    total_tokens: int


class ModelResponse:
    # Return type of `acompletion`; only the attributes below are typed.
    id: str
    created: int
    model: _LiteLLMGeminiModel
    object: Literal["chat.completion"]
    system_fingerprint: str | None
    choices: list[_LiteLLMModelResponseChoice]
    usage: _LiteLLMUsage


# Async chat completion. Every parameter is keyword-only; `response_format`
# and `max_tokens` default to None and `temperature` defaults to 1.0.
async def acompletion(
    *,
    model: _LiteLLMGeminiModel,
    messages: list[_LiteLLMMessage],
    response_format: _LiteLLMResponseFormat | None = None,
    max_tokens: int | None = None,
    temperature: float = 1.0,
) -> ModelResponse: ...
|
goose_py-0.7.0/PKG-INFO
ADDED
@@ -0,0 +1,14 @@
|
|
1
|
+
Metadata-Version: 2.4
|
2
|
+
Name: goose-py
|
3
|
+
Version: 0.7.0
|
4
|
+
Summary: A tool for AI workflows based on human-computer collaboration and structured output.
|
5
|
+
Author-email: Nash Taylor <nash@chelle.ai>, Joshua Cook <joshua@chelle.ai>, Michael Sankur <michael@chelle.ai>
|
6
|
+
Requires-Python: >=3.12
|
7
|
+
Requires-Dist: jsonpath-ng>=1.7.0
|
8
|
+
Requires-Dist: litellm>=1.56.5
|
9
|
+
Requires-Dist: pydantic>=2.8.2
|
10
|
+
Description-Content-Type: text/markdown
|
11
|
+
|
12
|
+
# Goose
|
13
|
+
|
14
|
+
Docs to come.
|
@@ -1,100 +1,18 @@
|
|
1
1
|
import json
|
2
2
|
import logging
|
3
3
|
from datetime import datetime
|
4
|
-
from
|
5
|
-
from typing import Any, ClassVar, Literal, NotRequired, Protocol, TypedDict
|
4
|
+
from typing import Any, ClassVar, Protocol, TypedDict
|
6
5
|
|
7
6
|
from litellm import acompletion
|
8
7
|
from pydantic import BaseModel, computed_field
|
9
|
-
from goose.result import Result, TextResult
|
10
8
|
|
11
|
-
|
12
|
-
|
13
|
-
|
14
|
-
|
15
|
-
|
16
|
-
|
17
|
-
|
18
|
-
class UserMediaContentType(StrEnum):
|
19
|
-
# images
|
20
|
-
JPEG = "image/jpeg"
|
21
|
-
PNG = "image/png"
|
22
|
-
WEBP = "image/webp"
|
23
|
-
|
24
|
-
# audio
|
25
|
-
MP3 = "audio/mp3"
|
26
|
-
WAV = "audio/wav"
|
27
|
-
|
28
|
-
# files
|
29
|
-
PDF = "application/pdf"
|
30
|
-
|
31
|
-
|
32
|
-
class LLMTextMessagePart(TypedDict):
|
33
|
-
type: Literal["text"]
|
34
|
-
text: str
|
35
|
-
|
36
|
-
|
37
|
-
class LLMMediaMessagePart(TypedDict):
|
38
|
-
type: Literal["image_url"]
|
39
|
-
image_url: str
|
40
|
-
|
41
|
-
|
42
|
-
class CacheControl(TypedDict):
|
43
|
-
type: Literal["ephemeral"]
|
44
|
-
|
45
|
-
|
46
|
-
class LLMMessage(TypedDict):
|
47
|
-
role: Literal["user", "assistant", "system"]
|
48
|
-
content: list[LLMTextMessagePart | LLMMediaMessagePart]
|
49
|
-
cache_control: NotRequired[CacheControl]
|
50
|
-
|
51
|
-
|
52
|
-
class TextMessagePart(BaseModel):
|
53
|
-
text: str
|
54
|
-
|
55
|
-
def render(self) -> LLMTextMessagePart:
|
56
|
-
return {"type": "text", "text": self.text}
|
57
|
-
|
58
|
-
|
59
|
-
class MediaMessagePart(BaseModel):
|
60
|
-
content_type: UserMediaContentType
|
61
|
-
content: str
|
62
|
-
|
63
|
-
def render(self) -> LLMMediaMessagePart:
|
64
|
-
return {
|
65
|
-
"type": "image_url",
|
66
|
-
"image_url": f"data:{self.content_type};base64,{self.content}",
|
67
|
-
}
|
68
|
-
|
69
|
-
|
70
|
-
class UserMessage(BaseModel):
|
71
|
-
parts: list[TextMessagePart | MediaMessagePart]
|
72
|
-
|
73
|
-
def render(self) -> LLMMessage:
|
74
|
-
content: LLMMessage = {
|
75
|
-
"role": "user",
|
76
|
-
"content": [part.render() for part in self.parts],
|
77
|
-
}
|
78
|
-
if any(isinstance(part, MediaMessagePart) for part in self.parts):
|
79
|
-
content["cache_control"] = {"type": "ephemeral"}
|
80
|
-
return content
|
81
|
-
|
82
|
-
|
83
|
-
class AssistantMessage(BaseModel):
|
84
|
-
text: str
|
85
|
-
|
86
|
-
def render(self) -> LLMMessage:
|
87
|
-
return {"role": "assistant", "content": [{"type": "text", "text": self.text}]}
|
88
|
-
|
89
|
-
|
90
|
-
class SystemMessage(BaseModel):
|
91
|
-
parts: list[TextMessagePart | MediaMessagePart]
|
92
|
-
|
93
|
-
def render(self) -> LLMMessage:
|
94
|
-
return {
|
95
|
-
"role": "system",
|
96
|
-
"content": [part.render() for part in self.parts],
|
97
|
-
}
|
9
|
+
from goose._internal.result import Result, TextResult
|
10
|
+
from goose._internal.types.agent import (
|
11
|
+
AssistantMessage,
|
12
|
+
GeminiModel,
|
13
|
+
SystemMessage,
|
14
|
+
UserMessage,
|
15
|
+
)
|
98
16
|
|
99
17
|
|
100
18
|
class AgentResponseDump(TypedDict):
|
@@ -0,0 +1,41 @@
|
|
1
|
+
from pydantic import BaseModel
|
2
|
+
|
3
|
+
from goose._internal.result import Result
|
4
|
+
from goose._internal.types.agent import (
|
5
|
+
AssistantMessage,
|
6
|
+
LLMMessage,
|
7
|
+
SystemMessage,
|
8
|
+
UserMessage,
|
9
|
+
)
|
10
|
+
|
11
|
+
|
12
|
+
class Conversation[R: Result](BaseModel):
|
13
|
+
user_messages: list[UserMessage]
|
14
|
+
result_messages: list[R]
|
15
|
+
context: SystemMessage | None = None
|
16
|
+
|
17
|
+
@property
|
18
|
+
def awaiting_response(self) -> bool:
|
19
|
+
return len(self.user_messages) == len(self.result_messages)
|
20
|
+
|
21
|
+
def render(self) -> list[LLMMessage]:
|
22
|
+
messages: list[LLMMessage] = []
|
23
|
+
if self.context is not None:
|
24
|
+
messages.append(self.context.render())
|
25
|
+
|
26
|
+
for message_index in range(len(self.user_messages)):
|
27
|
+
messages.append(
|
28
|
+
AssistantMessage(
|
29
|
+
text=self.result_messages[message_index].model_dump_json()
|
30
|
+
).render()
|
31
|
+
)
|
32
|
+
messages.append(self.user_messages[message_index].render())
|
33
|
+
|
34
|
+
if len(self.result_messages) > len(self.user_messages):
|
35
|
+
messages.append(
|
36
|
+
AssistantMessage(
|
37
|
+
text=self.result_messages[-1].model_dump_json()
|
38
|
+
).render()
|
39
|
+
)
|
40
|
+
|
41
|
+
return messages
|
@@ -0,0 +1,106 @@
|
|
1
|
+
from contextlib import asynccontextmanager
|
2
|
+
from types import CodeType
|
3
|
+
from typing import AsyncIterator, Awaitable, Callable, Protocol, overload
|
4
|
+
|
5
|
+
from goose._internal.agent import Agent, IAgentLogger
|
6
|
+
from goose._internal.conversation import Conversation
|
7
|
+
from goose._internal.result import Result
|
8
|
+
from goose._internal.state import FlowRun, get_current_flow_run, set_current_flow_run
|
9
|
+
from goose._internal.store import IFlowRunStore, InMemoryFlowRunStore
|
10
|
+
from goose.errors import Honk
|
11
|
+
|
12
|
+
|
13
|
+
class IAdapter[ResultT: Result](Protocol):
    """Structural type for an async adapter callable.

    A matching object exposes a ``__code__`` attribute and, when awaited with
    a ``conversation`` and an ``agent`` (keyword-only), returns a ``ResultT``.
    """

    # Plain Python functions carry `__code__`, so an ordinary `async def`
    # with the right signature satisfies this protocol.
    __code__: CodeType

    async def __call__(
        self, *, conversation: Conversation[ResultT], agent: Agent
    ) -> ResultT: ...
|
19
|
+
|
20
|
+
|
21
|
+
class Flow[**P]:
|
22
|
+
def __init__(
|
23
|
+
self,
|
24
|
+
fn: Callable[P, Awaitable[None]],
|
25
|
+
/,
|
26
|
+
*,
|
27
|
+
name: str | None = None,
|
28
|
+
store: IFlowRunStore | None = None,
|
29
|
+
agent_logger: IAgentLogger | None = None,
|
30
|
+
) -> None:
|
31
|
+
self._fn = fn
|
32
|
+
self._name = name
|
33
|
+
self._agent_logger = agent_logger
|
34
|
+
self._store = store or InMemoryFlowRunStore(flow_name=self.name)
|
35
|
+
|
36
|
+
@property
|
37
|
+
def name(self) -> str:
|
38
|
+
return self._name or self._fn.__name__
|
39
|
+
|
40
|
+
@property
|
41
|
+
def current_run(self) -> FlowRun:
|
42
|
+
run = get_current_flow_run()
|
43
|
+
if run is None:
|
44
|
+
raise Honk("No current flow run")
|
45
|
+
return run
|
46
|
+
|
47
|
+
@asynccontextmanager
|
48
|
+
async def start_run(self, *, run_id: str) -> AsyncIterator[FlowRun]:
|
49
|
+
existing_run = await self._store.get(run_id=run_id)
|
50
|
+
if existing_run is None:
|
51
|
+
run = FlowRun()
|
52
|
+
else:
|
53
|
+
run = existing_run
|
54
|
+
|
55
|
+
old_run = get_current_flow_run()
|
56
|
+
set_current_flow_run(run)
|
57
|
+
|
58
|
+
run.start(flow_name=self.name, run_id=run_id, agent_logger=self._agent_logger)
|
59
|
+
yield run
|
60
|
+
await self._store.save(run=run)
|
61
|
+
run.end()
|
62
|
+
|
63
|
+
set_current_flow_run(old_run)
|
64
|
+
|
65
|
+
async def generate(self, *args: P.args, **kwargs: P.kwargs) -> None:
|
66
|
+
flow_run = get_current_flow_run()
|
67
|
+
if flow_run is None:
|
68
|
+
raise Honk("No current flow run")
|
69
|
+
|
70
|
+
flow_run.set_flow_inputs(*args, **kwargs)
|
71
|
+
await self._fn(*args, **kwargs)
|
72
|
+
|
73
|
+
async def regenerate(self) -> None:
|
74
|
+
flow_run = get_current_flow_run()
|
75
|
+
if flow_run is None:
|
76
|
+
raise Honk("No current flow run")
|
77
|
+
|
78
|
+
flow_args, flow_kwargs = flow_run.flow_inputs
|
79
|
+
await self._fn(*flow_args, **flow_kwargs)
|
80
|
+
|
81
|
+
|
82
|
+
@overload
|
83
|
+
def flow[**P](fn: Callable[P, Awaitable[None]], /) -> Flow[P]: ...
|
84
|
+
@overload
|
85
|
+
def flow[**P](
|
86
|
+
*,
|
87
|
+
name: str | None = None,
|
88
|
+
store: IFlowRunStore | None = None,
|
89
|
+
agent_logger: IAgentLogger | None = None,
|
90
|
+
) -> Callable[[Callable[P, Awaitable[None]]], Flow[P]]: ...
|
91
|
+
def flow[**P](
|
92
|
+
fn: Callable[P, Awaitable[None]] | None = None,
|
93
|
+
/,
|
94
|
+
*,
|
95
|
+
name: str | None = None,
|
96
|
+
store: IFlowRunStore | None = None,
|
97
|
+
agent_logger: IAgentLogger | None = None,
|
98
|
+
) -> Flow[P] | Callable[[Callable[P, Awaitable[None]]], Flow[P]]:
|
99
|
+
if fn is None:
|
100
|
+
|
101
|
+
def decorator(fn: Callable[P, Awaitable[None]]) -> Flow[P]:
|
102
|
+
return Flow(fn, name=name, store=store, agent_logger=agent_logger)
|
103
|
+
|
104
|
+
return decorator
|
105
|
+
|
106
|
+
return Flow(fn, name=name, store=store, agent_logger=agent_logger)
|
@@ -0,0 +1,190 @@
|
|
1
|
+
from contextvars import ContextVar
|
2
|
+
from dataclasses import dataclass
|
3
|
+
from typing import TYPE_CHECKING, Any, Self
|
4
|
+
|
5
|
+
from pydantic import BaseModel
|
6
|
+
|
7
|
+
from goose._internal.agent import (
|
8
|
+
Agent,
|
9
|
+
IAgentLogger,
|
10
|
+
SystemMessage,
|
11
|
+
UserMessage,
|
12
|
+
)
|
13
|
+
from goose._internal.conversation import Conversation
|
14
|
+
from goose._internal.result import Result
|
15
|
+
from goose.errors import Honk
|
16
|
+
|
17
|
+
if TYPE_CHECKING:
|
18
|
+
from goose._internal.task import Task
|
19
|
+
|
20
|
+
|
21
|
+
@dataclass
class FlowRunState:
    """Storable snapshot of a FlowRun, produced by ``FlowRun.dump`` and read by ``FlowRun.load``."""

    # Keyed by (task name, node index); each value is a NodeState serialized
    # with `model_dump_json`.
    node_states: dict[tuple[str, int], str]
    # Positional and keyword arguments recorded for the flow function.
    flow_args: tuple[Any, ...]
    flow_kwargs: dict[str, Any]
|
26
|
+
|
27
|
+
|
28
|
+
class NodeState[ResultT: Result](BaseModel):
|
29
|
+
task_name: str
|
30
|
+
index: int
|
31
|
+
conversation: Conversation[ResultT]
|
32
|
+
last_hash: int
|
33
|
+
|
34
|
+
@property
|
35
|
+
def result(self) -> ResultT:
|
36
|
+
if len(self.conversation.result_messages) == 0:
|
37
|
+
raise Honk("Node awaiting response, has no result")
|
38
|
+
|
39
|
+
return self.conversation.result_messages[-1]
|
40
|
+
|
41
|
+
def set_context(self, *, context: SystemMessage) -> Self:
|
42
|
+
self.conversation.context = context
|
43
|
+
return self
|
44
|
+
|
45
|
+
def add_result(
|
46
|
+
self,
|
47
|
+
*,
|
48
|
+
result: ResultT,
|
49
|
+
new_hash: int | None = None,
|
50
|
+
overwrite: bool = False,
|
51
|
+
) -> Self:
|
52
|
+
if overwrite and len(self.conversation.result_messages) > 0:
|
53
|
+
self.conversation.result_messages[-1] = result
|
54
|
+
else:
|
55
|
+
self.conversation.result_messages.append(result)
|
56
|
+
if new_hash is not None:
|
57
|
+
self.last_hash = new_hash
|
58
|
+
return self
|
59
|
+
|
60
|
+
def add_user_message(self, *, message: UserMessage) -> Self:
|
61
|
+
self.conversation.user_messages.append(message)
|
62
|
+
return self
|
63
|
+
|
64
|
+
|
65
|
+
class FlowRun:
|
66
|
+
def __init__(self) -> None:
|
67
|
+
self._node_states: dict[tuple[str, int], str] = {}
|
68
|
+
self._last_requested_indices: dict[str, int] = {}
|
69
|
+
self._flow_name = ""
|
70
|
+
self._id = ""
|
71
|
+
self._agent: Agent | None = None
|
72
|
+
self._flow_args: tuple[Any, ...] | None = None
|
73
|
+
self._flow_kwargs: dict[str, Any] | None = None
|
74
|
+
|
75
|
+
@property
|
76
|
+
def flow_name(self) -> str:
|
77
|
+
return self._flow_name
|
78
|
+
|
79
|
+
@property
|
80
|
+
def id(self) -> str:
|
81
|
+
return self._id
|
82
|
+
|
83
|
+
@property
|
84
|
+
def agent(self) -> Agent:
|
85
|
+
if self._agent is None:
|
86
|
+
raise Honk("Agent is only accessible once a run is started")
|
87
|
+
return self._agent
|
88
|
+
|
89
|
+
@property
|
90
|
+
def flow_inputs(self) -> tuple[tuple[Any, ...], dict[str, Any]]:
|
91
|
+
if self._flow_args is None or self._flow_kwargs is None:
|
92
|
+
raise Honk("This Flow run has not been executed before")
|
93
|
+
|
94
|
+
return self._flow_args, self._flow_kwargs
|
95
|
+
|
96
|
+
def get_all[R: Result](self, *, task: "Task[Any, R]") -> list[NodeState[R]]:
|
97
|
+
matching_nodes: list[NodeState[R]] = []
|
98
|
+
for key, node_state in self._node_states.items():
|
99
|
+
if key[0] == task.name:
|
100
|
+
matching_nodes.append(
|
101
|
+
NodeState[task.result_type].model_validate_json(node_state)
|
102
|
+
)
|
103
|
+
return sorted(matching_nodes, key=lambda node: node.index)
|
104
|
+
|
105
|
+
def get[R: Result](self, *, task: "Task[Any, R]", index: int = 0) -> NodeState[R]:
|
106
|
+
if (
|
107
|
+
existing_node_state := self._node_states.get((task.name, index))
|
108
|
+
) is not None:
|
109
|
+
return NodeState[task.result_type].model_validate_json(existing_node_state)
|
110
|
+
else:
|
111
|
+
return NodeState[task.result_type](
|
112
|
+
task_name=task.name,
|
113
|
+
index=index,
|
114
|
+
conversation=Conversation[task.result_type](
|
115
|
+
user_messages=[], result_messages=[]
|
116
|
+
),
|
117
|
+
last_hash=0,
|
118
|
+
)
|
119
|
+
|
120
|
+
def set_flow_inputs(self, *args: Any, **kwargs: Any) -> None:
|
121
|
+
self._flow_args = args
|
122
|
+
self._flow_kwargs = kwargs
|
123
|
+
|
124
|
+
def add_node_state(self, node_state: NodeState[Any], /) -> None:
|
125
|
+
key = (node_state.task_name, node_state.index)
|
126
|
+
self._node_states[key] = node_state.model_dump_json()
|
127
|
+
|
128
|
+
def get_next[R: Result](self, *, task: "Task[Any, R]") -> NodeState[R]:
|
129
|
+
if task.name not in self._last_requested_indices:
|
130
|
+
self._last_requested_indices[task.name] = 0
|
131
|
+
else:
|
132
|
+
self._last_requested_indices[task.name] += 1
|
133
|
+
|
134
|
+
return self.get(task=task, index=self._last_requested_indices[task.name])
|
135
|
+
|
136
|
+
def start(
|
137
|
+
self,
|
138
|
+
*,
|
139
|
+
flow_name: str,
|
140
|
+
run_id: str,
|
141
|
+
agent_logger: IAgentLogger | None = None,
|
142
|
+
) -> None:
|
143
|
+
self._last_requested_indices = {}
|
144
|
+
self._flow_name = flow_name
|
145
|
+
self._id = run_id
|
146
|
+
self._agent = Agent(
|
147
|
+
flow_name=self.flow_name, run_id=self.id, logger=agent_logger
|
148
|
+
)
|
149
|
+
|
150
|
+
def end(self) -> None:
|
151
|
+
self._last_requested_indices = {}
|
152
|
+
self._flow_name = ""
|
153
|
+
self._id = ""
|
154
|
+
self._agent = None
|
155
|
+
|
156
|
+
def clear_node(self, *, task: "Task[Any, Result]", index: int) -> None:
|
157
|
+
key = (task.name, index)
|
158
|
+
if key in self._node_states:
|
159
|
+
del self._node_states[key]
|
160
|
+
|
161
|
+
def dump(self) -> FlowRunState:
|
162
|
+
flow_args, flow_kwargs = self.flow_inputs
|
163
|
+
|
164
|
+
return FlowRunState(
|
165
|
+
node_states=self._node_states,
|
166
|
+
flow_args=flow_args,
|
167
|
+
flow_kwargs=flow_kwargs,
|
168
|
+
)
|
169
|
+
|
170
|
+
@classmethod
|
171
|
+
def load(cls, flow_run_state: FlowRunState, /) -> Self:
|
172
|
+
flow_run = cls()
|
173
|
+
flow_run._node_states = flow_run_state.node_states
|
174
|
+
flow_run._flow_args = flow_run_state.flow_args
|
175
|
+
flow_run._flow_kwargs = flow_run_state.flow_kwargs
|
176
|
+
|
177
|
+
return flow_run
|
178
|
+
|
179
|
+
|
180
|
+
_current_flow_run: ContextVar[FlowRun | None] = ContextVar("current_flow_run", default=None)


def get_current_flow_run() -> FlowRun | None:
    """Return the FlowRun bound in the current context, or None if none is bound."""
    active_run = _current_flow_run.get()
    return active_run


def set_current_flow_run(flow_run: FlowRun | None) -> None:
    """Bind *flow_run* (or None) as the current context's FlowRun.

    The reset token returned by ``ContextVar.set`` is discarded.
    """
    _current_flow_run.set(flow_run)
    return None
|
@@ -1,9 +1,9 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
|
-
from typing import
|
3
|
+
from typing import Protocol
|
4
4
|
|
5
|
-
|
6
|
-
|
5
|
+
from goose._internal.flow import FlowRun
|
6
|
+
from goose._internal.state import FlowRunState
|
7
7
|
|
8
8
|
|
9
9
|
class IFlowRunStore(Protocol):
|
@@ -16,13 +16,15 @@ class IFlowRunStore(Protocol):
|
|
16
16
|
class InMemoryFlowRunStore(IFlowRunStore):
|
17
17
|
def __init__(self, *, flow_name: str) -> None:
|
18
18
|
self._flow_name = flow_name
|
19
|
-
self._runs: dict[str,
|
19
|
+
self._runs: dict[str, FlowRunState] = {}
|
20
20
|
|
21
21
|
async def get(self, *, run_id: str) -> FlowRun | None:
|
22
|
-
|
22
|
+
state = self._runs.get(run_id)
|
23
|
+
if state is not None:
|
24
|
+
return FlowRun.load(state)
|
23
25
|
|
24
26
|
async def save(self, *, run: FlowRun) -> None:
|
25
|
-
self._runs[run.id] = run
|
27
|
+
self._runs[run.id] = run.dump()
|
26
28
|
|
27
29
|
async def delete(self, *, run_id: str) -> None:
|
28
30
|
self._runs.pop(run_id, None)
|