goose-py 0.10.2__tar.gz → 0.11.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38)
  1. {goose_py-0.10.2 → goose_py-0.11.1}/PKG-INFO +2 -2
  2. goose_py-0.11.1/goose/__init__.py +7 -0
  3. {goose_py-0.10.2 → goose_py-0.11.1}/goose/_internal/agent.py +75 -84
  4. {goose_py-0.10.2 → goose_py-0.11.1}/goose/_internal/conversation.py +12 -17
  5. {goose_py-0.10.2 → goose_py-0.11.1}/goose/_internal/flow.py +7 -7
  6. {goose_py-0.10.2 → goose_py-0.11.1}/goose/_internal/state.py +51 -35
  7. {goose_py-0.10.2 → goose_py-0.11.1}/goose/_internal/store.py +1 -1
  8. {goose_py-0.10.2 → goose_py-0.11.1}/goose/_internal/task.py +25 -22
  9. {goose_py-0.10.2 → goose_py-0.11.1}/goose/_internal/types/telemetry.py +18 -19
  10. {goose_py-0.10.2 → goose_py-0.11.1}/pyproject.toml +2 -2
  11. {goose_py-0.10.2 → goose_py-0.11.1}/tests/test_agent.py +12 -20
  12. {goose_py-0.10.2 → goose_py-0.11.1}/tests/test_ask.py +17 -23
  13. {goose_py-0.10.2 → goose_py-0.11.1}/tests/test_downstream_task.py +2 -2
  14. {goose_py-0.10.2 → goose_py-0.11.1}/tests/test_hashing.py +3 -3
  15. {goose_py-0.10.2 → goose_py-0.11.1}/tests/test_looping.py +3 -3
  16. {goose_py-0.10.2 → goose_py-0.11.1}/tests/test_refining.py +10 -12
  17. {goose_py-0.10.2 → goose_py-0.11.1}/tests/test_regenerate.py +2 -2
  18. {goose_py-0.10.2 → goose_py-0.11.1}/tests/test_state.py +7 -7
  19. {goose_py-0.10.2 → goose_py-0.11.1}/uv.lock +20 -6
  20. goose_py-0.10.2/.stubs/litellm/__init__.pyi +0 -62
  21. goose_py-0.10.2/goose/__init__.py +0 -6
  22. goose_py-0.10.2/goose/_internal/types/agent.py +0 -101
  23. goose_py-0.10.2/goose/agent.py +0 -26
  24. {goose_py-0.10.2 → goose_py-0.11.1}/.envrc +0 -0
  25. {goose_py-0.10.2 → goose_py-0.11.1}/.github/workflows/publish.yml +0 -0
  26. {goose_py-0.10.2 → goose_py-0.11.1}/.gitignore +0 -0
  27. {goose_py-0.10.2 → goose_py-0.11.1}/.python-version +0 -0
  28. {goose_py-0.10.2 → goose_py-0.11.1}/.stubs/jsonpath_ng/__init__.pyi +0 -0
  29. {goose_py-0.10.2 → goose_py-0.11.1}/Makefile +0 -0
  30. {goose_py-0.10.2 → goose_py-0.11.1}/README.md +0 -0
  31. {goose_py-0.10.2 → goose_py-0.11.1}/goose/_internal/result.py +0 -0
  32. {goose_py-0.10.2 → goose_py-0.11.1}/goose/_internal/types/__init__.py +0 -0
  33. {goose_py-0.10.2 → goose_py-0.11.1}/goose/errors.py +0 -0
  34. {goose_py-0.10.2 → goose_py-0.11.1}/goose/flow.py +0 -0
  35. {goose_py-0.10.2 → goose_py-0.11.1}/goose/py.typed +0 -0
  36. {goose_py-0.10.2 → goose_py-0.11.1}/goose/runs.py +0 -0
  37. {goose_py-0.10.2 → goose_py-0.11.1}/goose/task.py +0 -0
  38. {goose_py-0.10.2 → goose_py-0.11.1}/tests/__init__.py +0 -0
@@ -1,11 +1,11 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: goose-py
3
- Version: 0.10.2
3
+ Version: 0.11.1
4
4
  Summary: A tool for AI workflows based on human-computer collaboration and structured output.
5
5
  Author-email: Nash Taylor <nash@chelle.ai>, Joshua Cook <joshua@chelle.ai>, Michael Sankur <michael@chelle.ai>
6
6
  Requires-Python: >=3.12
7
+ Requires-Dist: aikernel>=0.1.8
7
8
  Requires-Dist: jsonpath-ng>=1.7.0
8
- Requires-Dist: litellm>=1.56.5
9
9
  Requires-Dist: pydantic>=2.8.2
10
10
  Description-Content-Type: text/markdown
11
11
 
@@ -0,0 +1,7 @@
1
+ from goose._internal.agent import Agent
2
+ from goose._internal.flow import FlowArguments, flow
3
+ from goose._internal.result import Result, TextResult
4
+ from goose._internal.task import task
5
+ from goose._internal.types.telemetry import AgentResponse
6
+
7
+ __all__ = ["Agent", "flow", "FlowArguments", "Result", "TextResult", "task", "AgentResponse"]
@@ -2,11 +2,17 @@ import logging
2
2
  from datetime import datetime
3
3
  from typing import Any, Literal, Protocol, overload
4
4
 
5
- from litellm import acompletion
5
+ from aikernel import (
6
+ LLMAssistantMessage,
7
+ LLMModel,
8
+ LLMSystemMessage,
9
+ LLMUserMessage,
10
+ llm_structured,
11
+ llm_unstructured,
12
+ )
6
13
  from pydantic import ValidationError
7
14
 
8
15
  from goose._internal.result import FindReplaceResponse, Result, TextResult
9
- from goose._internal.types.agent import AIModel, AssistantMessage, SystemMessage, UserMessage
10
16
  from goose._internal.types.telemetry import AgentResponse
11
17
  from goose.errors import Honk
12
18
 
@@ -30,42 +36,39 @@ class Agent:
30
36
  async def generate[R: Result](
31
37
  self,
32
38
  *,
33
- messages: list[UserMessage | AssistantMessage],
34
- model: AIModel,
39
+ messages: list[LLMUserMessage | LLMAssistantMessage | LLMSystemMessage],
40
+ model: LLMModel,
35
41
  task_name: str,
36
42
  response_model: type[R] = TextResult,
37
- system: SystemMessage | None = None,
38
43
  ) -> R:
39
- rendered_messages = [message.render() for message in messages]
40
- rendered_system = system.render() if system is not None else None
41
-
42
- completion_messages = (
43
- [rendered_system] + rendered_messages if rendered_system is not None else rendered_messages
44
- )
45
-
46
44
  start_time = datetime.now()
45
+
47
46
  if response_model is TextResult:
48
- response = await acompletion(model=model.value, messages=completion_messages)
49
- parsed_response = response_model.model_validate({"text": response.choices[0].message.content})
47
+ response = await llm_unstructured(model=model, messages=messages)
48
+ parsed_response = response_model.model_validate({"text": response.text})
50
49
  else:
51
- response = await acompletion(
52
- model=model.value,
53
- messages=completion_messages,
54
- response_format=response_model,
55
- )
56
- parsed_response = response_model.model_validate_json(response.choices[0].message.content)
50
+ response = await llm_structured(model=model, messages=messages, response_model=response_model)
51
+ parsed_response = response.structured_response
57
52
 
58
53
  end_time = datetime.now()
54
+
55
+ if isinstance(messages[0], LLMSystemMessage):
56
+ system = messages[0].render()
57
+ input_messages = [message.render() for message in messages[1:]]
58
+ else:
59
+ system = None
60
+ input_messages = [message.render() for message in messages]
61
+
59
62
  agent_response = AgentResponse(
60
63
  response=parsed_response,
61
64
  run_id=self.run_id,
62
65
  flow_name=self.flow_name,
63
66
  task_name=task_name,
64
67
  model=model,
65
- system=rendered_system,
66
- input_messages=rendered_messages,
67
- input_tokens=response.usage.prompt_tokens,
68
- output_tokens=response.usage.completion_tokens,
68
+ system=system,
69
+ input_messages=input_messages,
70
+ input_tokens=response.usage.input_tokens,
71
+ output_tokens=response.usage.output_tokens,
69
72
  start_time=start_time,
70
73
  end_time=end_time,
71
74
  )
@@ -80,32 +83,31 @@ class Agent:
80
83
  async def ask(
81
84
  self,
82
85
  *,
83
- messages: list[UserMessage | AssistantMessage],
84
- model: AIModel,
86
+ messages: list[LLMUserMessage | LLMAssistantMessage | LLMSystemMessage],
87
+ model: LLMModel,
85
88
  task_name: str,
86
- system: SystemMessage | None = None,
87
89
  ) -> str:
88
- rendered_messages = [message.render() for message in messages]
89
- rendered_system = system.render() if system is not None else None
90
-
91
- completion_messages = (
92
- [rendered_system] + rendered_messages if rendered_system is not None else rendered_messages
93
- )
94
-
95
90
  start_time = datetime.now()
96
- response = await acompletion(model=model.value, messages=completion_messages)
97
-
91
+ response = await llm_unstructured(model=model, messages=messages)
98
92
  end_time = datetime.now()
93
+
94
+ if isinstance(messages[0], LLMSystemMessage):
95
+ system = messages[0].render()
96
+ input_messages = [message.render() for message in messages[1:]]
97
+ else:
98
+ system = None
99
+ input_messages = [message.render() for message in messages]
100
+
99
101
  agent_response = AgentResponse(
100
- response=response.choices[0].message.content,
102
+ response=response.text,
101
103
  run_id=self.run_id,
102
104
  flow_name=self.flow_name,
103
105
  task_name=task_name,
104
106
  model=model,
105
- system=rendered_system,
106
- input_messages=rendered_messages,
107
- input_tokens=response.usage.prompt_tokens,
108
- output_tokens=response.usage.completion_tokens,
107
+ system=system,
108
+ input_messages=input_messages,
109
+ input_tokens=response.usage.input_tokens,
110
+ output_tokens=response.usage.output_tokens,
109
111
  start_time=start_time,
110
112
  end_time=end_time,
111
113
  )
@@ -115,44 +117,38 @@ class Agent:
115
117
  else:
116
118
  logging.info(agent_response.model_dump())
117
119
 
118
- return response.choices[0].message.content
120
+ return response.text
119
121
 
120
122
  async def refine[R: Result](
121
123
  self,
122
124
  *,
123
- messages: list[UserMessage | AssistantMessage],
124
- model: AIModel,
125
+ messages: list[LLMUserMessage | LLMAssistantMessage | LLMSystemMessage],
126
+ model: LLMModel,
125
127
  task_name: str,
126
128
  response_model: type[R],
127
- system: SystemMessage | None = None,
128
129
  ) -> R:
129
130
  start_time = datetime.now()
131
+ find_replace_response = await llm_structured(model=model, messages=messages, response_model=FindReplaceResponse)
132
+ parsed_find_replace_response = find_replace_response.structured_response
133
+ end_time = datetime.now()
130
134
 
131
- rendered_messages = [message.render() for message in messages]
132
- rendered_system = system.render() if system is not None else None
133
-
134
- completion_messages = (
135
- [rendered_system] + rendered_messages if rendered_system is not None else rendered_messages
136
- )
137
-
138
- find_replace_response = await acompletion(
139
- model=model.value, messages=completion_messages, response_format=FindReplaceResponse
140
- )
141
- parsed_find_replace_response = FindReplaceResponse.model_validate_json(
142
- find_replace_response.choices[0].message.content
143
- )
135
+ if isinstance(messages[0], LLMSystemMessage):
136
+ system = messages[0].render()
137
+ input_messages = [message.render() for message in messages[1:]]
138
+ else:
139
+ system = None
140
+ input_messages = [message.render() for message in messages]
144
141
 
145
- end_time = datetime.now()
146
142
  agent_response = AgentResponse(
147
143
  response=parsed_find_replace_response,
148
144
  run_id=self.run_id,
149
145
  flow_name=self.flow_name,
150
146
  task_name=task_name,
151
147
  model=model,
152
- system=rendered_system,
153
- input_messages=rendered_messages,
154
- input_tokens=find_replace_response.usage.prompt_tokens,
155
- output_tokens=find_replace_response.usage.completion_tokens,
148
+ system=system,
149
+ input_messages=input_messages,
150
+ input_tokens=find_replace_response.usage.input_tokens,
151
+ output_tokens=find_replace_response.usage.output_tokens,
156
152
  start_time=start_time,
157
153
  end_time=end_time,
158
154
  )
@@ -174,69 +170,64 @@ class Agent:
174
170
  async def __call__[R: Result](
175
171
  self,
176
172
  *,
177
- messages: list[UserMessage | AssistantMessage],
178
- model: AIModel,
173
+ messages: list[LLMUserMessage | LLMAssistantMessage | LLMSystemMessage],
174
+ model: LLMModel,
179
175
  task_name: str,
180
176
  mode: Literal["generate"],
181
177
  response_model: type[R],
182
- system: SystemMessage | None = None,
183
178
  ) -> R: ...
184
179
 
185
180
  @overload
186
181
  async def __call__[R: Result](
187
182
  self,
188
183
  *,
189
- messages: list[UserMessage | AssistantMessage],
190
- model: AIModel,
184
+ messages: list[LLMUserMessage | LLMAssistantMessage | LLMSystemMessage],
185
+ model: LLMModel,
191
186
  task_name: str,
192
187
  mode: Literal["ask"],
193
188
  response_model: type[R] = TextResult,
194
- system: SystemMessage | None = None,
195
189
  ) -> str: ...
196
190
 
197
191
  @overload
198
192
  async def __call__[R: Result](
199
193
  self,
200
194
  *,
201
- messages: list[UserMessage | AssistantMessage],
202
- model: AIModel,
195
+ messages: list[LLMUserMessage | LLMAssistantMessage | LLMSystemMessage],
196
+ model: LLMModel,
203
197
  task_name: str,
204
198
  response_model: type[R],
205
199
  mode: Literal["refine"],
206
- system: SystemMessage | None = None,
207
200
  ) -> R: ...
208
201
 
209
202
  @overload
210
203
  async def __call__[R: Result](
211
204
  self,
212
205
  *,
213
- messages: list[UserMessage | AssistantMessage],
214
- model: AIModel,
206
+ messages: list[LLMUserMessage | LLMAssistantMessage | LLMSystemMessage],
207
+ model: LLMModel,
215
208
  task_name: str,
216
209
  response_model: type[R],
217
- system: SystemMessage | None = None,
218
210
  ) -> R: ...
219
211
 
220
212
  async def __call__[R: Result](
221
213
  self,
222
214
  *,
223
- messages: list[UserMessage | AssistantMessage],
224
- model: AIModel,
215
+ messages: list[LLMUserMessage | LLMAssistantMessage | LLMSystemMessage],
216
+ model: LLMModel,
225
217
  task_name: str,
226
218
  response_model: type[R] = TextResult,
227
219
  mode: Literal["generate", "ask", "refine"] = "generate",
228
- system: SystemMessage | None = None,
229
220
  ) -> R | str:
230
221
  match mode:
231
222
  case "generate":
232
223
  return await self.generate(
233
- messages=messages, model=model, task_name=task_name, response_model=response_model, system=system
224
+ messages=messages, model=model, task_name=task_name, response_model=response_model
234
225
  )
235
226
  case "ask":
236
- return await self.ask(messages=messages, model=model, task_name=task_name, system=system)
227
+ return await self.ask(messages=messages, model=model, task_name=task_name)
237
228
  case "refine":
238
229
  return await self.refine(
239
- messages=messages, model=model, task_name=task_name, response_model=response_model, system=system
230
+ messages=messages, model=model, task_name=task_name, response_model=response_model
240
231
  )
241
232
 
242
233
  def __apply_find_replace[R: Result](
@@ -249,12 +240,12 @@ class Agent:
249
240
  return response_model.model_validate_json(dumped_result)
250
241
 
251
242
  def __find_last_result[R: Result](
252
- self, *, messages: list[UserMessage | AssistantMessage], response_model: type[R]
243
+ self, *, messages: list[LLMUserMessage | LLMAssistantMessage | LLMSystemMessage], response_model: type[R]
253
244
  ) -> R:
254
245
  for message in reversed(messages):
255
- if isinstance(message, AssistantMessage):
246
+ if isinstance(message, LLMAssistantMessage):
256
247
  try:
257
- return response_model.model_validate_json(message.text)
248
+ return response_model.model_validate_json(message.parts[0].content)
258
249
  except ValidationError:
259
250
  continue
260
251
 
@@ -1,38 +1,33 @@
1
1
  from typing import Self
2
2
 
3
+ from aikernel import LLMAssistantMessage, LLMSystemMessage, LLMUserMessage
3
4
  from pydantic import BaseModel
4
5
 
5
- from goose._internal.result import Result
6
- from goose._internal.types.agent import AssistantMessage, SystemMessage, UserMessage
7
6
  from goose.errors import Honk
8
7
 
9
8
 
10
- class Conversation[R: Result](BaseModel):
11
- user_messages: list[UserMessage]
12
- assistant_messages: list[R | str]
13
- context: SystemMessage | None = None
9
+ class Conversation(BaseModel):
10
+ user_messages: list[LLMUserMessage]
11
+ assistant_messages: list[LLMAssistantMessage]
12
+ context: LLMSystemMessage | None = None
14
13
 
15
14
  @property
16
15
  def awaiting_response(self) -> bool:
17
16
  return len(self.user_messages) == len(self.assistant_messages)
18
17
 
19
- def get_messages(self) -> list[UserMessage | AssistantMessage]:
20
- messages: list[UserMessage | AssistantMessage] = []
18
+ def render(self) -> list[LLMSystemMessage | LLMUserMessage | LLMAssistantMessage]:
19
+ messages: list[LLMSystemMessage | LLMUserMessage | LLMAssistantMessage] = []
20
+ if self.context is not None:
21
+ messages.append(self.context)
22
+
21
23
  for message_index in range(len(self.user_messages)):
22
24
  message = self.assistant_messages[message_index]
23
- if isinstance(message, str):
24
- messages.append(AssistantMessage(text=message))
25
- else:
26
- messages.append(AssistantMessage(text=message.model_dump_json()))
27
-
25
+ messages.append(message)
28
26
  messages.append(self.user_messages[message_index])
29
27
 
30
28
  if len(self.assistant_messages) > len(self.user_messages):
31
29
  message = self.assistant_messages[-1]
32
- if isinstance(message, str):
33
- messages.append(AssistantMessage(text=message))
34
- else:
35
- messages.append(AssistantMessage(text=message.model_dump_json()))
30
+ messages.append(message)
36
31
 
37
32
  return messages
38
33
 
@@ -3,12 +3,12 @@ from contextlib import asynccontextmanager
3
3
  from types import CodeType
4
4
  from typing import Protocol, overload
5
5
 
6
- from goose._internal.agent import Agent, IAgentLogger
7
- from goose._internal.conversation import Conversation
8
- from goose._internal.result import Result
9
- from goose._internal.state import FlowArguments, FlowRun, get_current_flow_run, set_current_flow_run
10
- from goose._internal.store import IFlowRunStore, InMemoryFlowRunStore
11
- from goose.errors import Honk
6
+ from ..errors import Honk
7
+ from .agent import Agent, IAgentLogger
8
+ from .conversation import Conversation
9
+ from .result import Result
10
+ from .state import FlowArguments, FlowRun, get_current_flow_run, set_current_flow_run
11
+ from .store import IFlowRunStore, InMemoryFlowRunStore
12
12
 
13
13
 
14
14
  class IGenerator[FlowArgumentsT: FlowArguments](Protocol):
@@ -20,7 +20,7 @@ class IGenerator[FlowArgumentsT: FlowArguments](Protocol):
20
20
  class IAdapter[ResultT: Result](Protocol):
21
21
  __code__: CodeType
22
22
 
23
- async def __call__(self, *, conversation: Conversation[ResultT], agent: Agent) -> ResultT: ...
23
+ async def __call__(self, *, conversation: Conversation, agent: Agent) -> ResultT: ...
24
24
 
25
25
 
26
26
  class Flow[FlowArgumentsT: FlowArguments]:
@@ -2,12 +2,12 @@ import json
2
2
  from contextvars import ContextVar
3
3
  from typing import TYPE_CHECKING, Any, NewType, Self
4
4
 
5
+ from aikernel import LLMAssistantMessage, LLMSystemMessage, LLMUserMessage
5
6
  from pydantic import BaseModel, ConfigDict
6
7
 
7
8
  from goose._internal.agent import Agent, IAgentLogger
8
9
  from goose._internal.conversation import Conversation
9
10
  from goose._internal.result import Result
10
- from goose._internal.types.agent import SystemMessage, UserMessage
11
11
  from goose.errors import Honk
12
12
 
13
13
  if TYPE_CHECKING:
@@ -20,55 +20,55 @@ class FlowArguments(BaseModel):
20
20
  model_config = ConfigDict(frozen=True)
21
21
 
22
22
 
23
- class NodeState[ResultT: Result](BaseModel):
23
+ class NodeState(BaseModel):
24
24
  task_name: str
25
25
  index: int
26
- conversation: Conversation[ResultT]
26
+ conversation: Conversation
27
27
  last_hash: int
28
28
 
29
29
  @property
30
- def result(self) -> ResultT:
30
+ def raw_result(self) -> str:
31
31
  for message in reversed(self.conversation.assistant_messages):
32
- if isinstance(message, Result):
33
- return message
32
+ if self.__message_is_result(message):
33
+ return message.parts[0].content
34
34
 
35
35
  raise Honk("Node awaiting response, has no result")
36
36
 
37
- def set_context(self, *, context: SystemMessage) -> Self:
37
+ def set_context(self, *, context: LLMSystemMessage) -> Self:
38
38
  self.conversation.context = context
39
39
  return self
40
40
 
41
41
  def add_result(
42
42
  self,
43
43
  *,
44
- result: ResultT,
44
+ result: str,
45
45
  new_hash: int | None = None,
46
46
  overwrite: bool = False,
47
47
  ) -> Self:
48
48
  if overwrite and len(self.conversation.assistant_messages) > 0:
49
- self.conversation.assistant_messages[-1] = result
49
+ self.conversation.assistant_messages[-1] = LLMAssistantMessage.from_text(result)
50
50
  else:
51
- self.conversation.assistant_messages.append(result)
51
+ self.conversation.assistant_messages.append(LLMAssistantMessage.from_text(result))
52
52
  if new_hash is not None:
53
53
  self.last_hash = new_hash
54
54
  return self
55
55
 
56
56
  def add_answer(self, *, answer: str) -> Self:
57
- self.conversation.assistant_messages.append(answer)
57
+ self.conversation.assistant_messages.append(LLMAssistantMessage.from_text(answer))
58
58
  return self
59
59
 
60
- def add_user_message(self, *, message: UserMessage) -> Self:
60
+ def add_user_message(self, *, message: LLMUserMessage) -> Self:
61
61
  self.conversation.user_messages.append(message)
62
62
  return self
63
63
 
64
- def edit_last_result(self, *, result: ResultT) -> Self:
64
+ def edit_last_result(self, *, result: str) -> Self:
65
65
  if len(self.conversation.assistant_messages) == 0:
66
66
  raise Honk("Node awaiting response, has no result")
67
67
 
68
68
  for message_index, message in enumerate(reversed(self.conversation.assistant_messages)):
69
- if isinstance(message, Result):
69
+ if self.__message_is_result(message):
70
70
  index = len(self.conversation.assistant_messages) - message_index - 1
71
- self.conversation.assistant_messages[index] = result
71
+ self.conversation.assistant_messages[index] = LLMAssistantMessage.from_text(result)
72
72
  return self
73
73
 
74
74
  raise Honk("Node awaiting response, has no result")
@@ -77,6 +77,13 @@ class NodeState[ResultT: Result](BaseModel):
77
77
  self.conversation.undo()
78
78
  return self
79
79
 
80
+ def __message_is_result(self, message: LLMAssistantMessage, /) -> bool:
81
+ try:
82
+ _ = json.loads(message.parts[0].content)
83
+ return True
84
+ except json.JSONDecodeError:
85
+ return False
86
+
80
87
 
81
88
  class FlowRun[FlowArgumentsT: FlowArguments]:
82
89
  def __init__(self, *, flow_arguments_model: type[FlowArgumentsT]) -> None:
@@ -109,38 +116,47 @@ class FlowRun[FlowArgumentsT: FlowArguments]:
109
116
 
110
117
  return self._flow_arguments
111
118
 
112
- def get_all[R: Result](self, *, task: "Task[Any, R]") -> list[NodeState[R]]:
113
- matching_nodes: list[NodeState[R]] = []
114
- for key, node_state in self._node_states.items():
115
- if key[0] == task.name:
116
- matching_nodes.append(NodeState[task.result_type].model_validate_json(node_state))
117
- return sorted(matching_nodes, key=lambda node: node.index)
118
-
119
- def get[R: Result](self, *, task: "Task[Any, R]", index: int = 0) -> NodeState[R]:
119
+ def get_state(self, *, task: "Task[Any, Any]", index: int = 0) -> NodeState:
120
120
  if (existing_node_state := self._node_states.get((task.name, index))) is not None:
121
- return NodeState[task.result_type].model_validate_json(existing_node_state)
121
+ return NodeState.model_validate_json(existing_node_state)
122
122
  else:
123
- return NodeState[task.result_type](
123
+ return NodeState(
124
124
  task_name=task.name,
125
125
  index=index,
126
- conversation=Conversation[task.result_type](user_messages=[], assistant_messages=[]),
126
+ conversation=Conversation(user_messages=[], assistant_messages=[]),
127
127
  last_hash=0,
128
128
  )
129
129
 
130
- def set_flow_arguments(self, flow_arguments: FlowArgumentsT, /) -> None:
131
- self._flow_arguments = flow_arguments
132
-
133
- def upsert_node_state(self, node_state: NodeState[Any], /) -> None:
134
- key = (node_state.task_name, node_state.index)
135
- self._node_states[key] = node_state.model_dump_json()
136
-
137
- def get_next[R: Result](self, *, task: "Task[Any, R]") -> NodeState[R]:
130
+ def get_next_state(self, *, task: "Task[Any, Any]", index: int = 0) -> NodeState:
138
131
  if task.name not in self._last_requested_indices:
139
132
  self._last_requested_indices[task.name] = 0
140
133
  else:
141
134
  self._last_requested_indices[task.name] += 1
142
135
 
143
- return self.get(task=task, index=self._last_requested_indices[task.name])
136
+ return self.get_state(task=task, index=self._last_requested_indices[task.name])
137
+
138
+ def get_all_results[R: Result](self, *, task: "Task[Any, R]") -> list[R]:
139
+ matching_nodes: list[NodeState] = []
140
+ for key, node_state in self._node_states.items():
141
+ if key[0] == task.name:
142
+ matching_nodes.append(NodeState.model_validate_json(node_state))
143
+
144
+ sorted_nodes = sorted(matching_nodes, key=lambda node: node.index)
145
+ return [task.result_type.model_validate_json(node.raw_result) for node in sorted_nodes]
146
+
147
+ def get_result[R: Result](self, *, task: "Task[Any, R]", index: int = 0) -> R:
148
+ if (existing_node_state := self._node_states.get((task.name, index))) is not None:
149
+ parsed_node_state = NodeState.model_validate_json(existing_node_state)
150
+ return task.result_type.model_validate_json(parsed_node_state.raw_result)
151
+ else:
152
+ raise Honk(f"No result found for task {task.name} at index {index}")
153
+
154
+ def set_flow_arguments(self, flow_arguments: FlowArgumentsT, /) -> None:
155
+ self._flow_arguments = flow_arguments
156
+
157
+ def upsert_node_state(self, node_state: NodeState, /) -> None:
158
+ key = (node_state.task_name, node_state.index)
159
+ self._node_states[key] = node_state.model_dump_json()
144
160
 
145
161
  def start(
146
162
  self,
@@ -2,7 +2,7 @@ from __future__ import annotations
2
2
 
3
3
  from typing import Protocol
4
4
 
5
- from goose._internal.state import SerializedFlowRun
5
+ from .state import SerializedFlowRun
6
6
 
7
7
 
8
8
  class IFlowRunStore(Protocol):