goose-py 0.1.4__tar.gz → 0.2.1__tar.gz

This diff shows the content changes between two publicly released versions of this package, as published to a supported registry. It is provided for informational purposes only and reflects the package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: goose-py
-Version: 0.1.4
+Version: 0.2.1
 Summary: A tool for AI workflows based on human-computer collaboration and structured output.
 Home-page: https://github.com/chelle-ai/goose
 Keywords: ai,yaml,configuration,llm
@@ -25,7 +25,7 @@ Project-URL: Documentation, https://github.com/chelle-ai/goose
 Project-URL: Repository, https://github.com/chelle-ai/goose
 Description-Content-Type: text/markdown
 
-# Vers
+# Goose
 
 Docs to come.
 
@@ -1,3 +1,3 @@
-# Vers
+# Goose
 
 Docs to come.
@@ -0,0 +1,3 @@
+from goose.flow import ConversationState, FlowState, Result, flow, task
+
+__all__ = ["ConversationState", "FlowState", "Result", "flow", "task"]
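With these re-exports, downstream code can import the flow primitives from the package root rather than from goose.flow (a minimal sketch, assuming the distribution's import package is goose, as the goose.flow path implies):

    from goose import ConversationState, FlowState, Result, flow, task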
@@ -1,8 +1,10 @@
 import base64
+import logging
 from datetime import datetime
 from enum import StrEnum
-from typing import ClassVar, Literal, NotRequired, TypedDict
+from typing import Any, Awaitable, Callable, ClassVar, Literal, NotRequired, TypedDict
 
+from litellm import acompletion
 from pydantic import BaseModel, computed_field
 
 
@@ -131,3 +133,68 @@ class AgentResponse[R: BaseModel](BaseModel):
             self.OUTPUT_CENTS_PER_MILLION_TOKENS[self.model] * self.output_tokens
         )
         return input_cost + output_cost
+
+
+class Agent:
+    def __init__(
+        self,
+        *,
+        flow_name: str,
+        run_id: str,
+        logger: Callable[[AgentResponse[Any]], Awaitable[None]] | None = None,
+    ) -> None:
+        self.flow_name = flow_name
+        self.run_id = run_id
+        self.logger = logger
+
+    async def __call__[R: BaseModel](
+        self,
+        *,
+        messages: list[UserMessage | AssistantMessage],
+        model: GeminiModel,
+        response_model: type[R],
+        task_name: str,
+        system: SystemMessage | None = None,
+    ) -> R:
+        start_time = datetime.now()
+        rendered_messages = [message.render() for message in messages]
+        if system is not None:
+            rendered_messages.insert(0, system.render())
+
+        response = await acompletion(
+            model=model.value,
+            messages=rendered_messages,
+            response_format={
+                "type": "json_object",
+                "response_schema": response_model.model_json_schema(),
+                "enforce_validation": True,
+            },
+        )
+
+        if len(response.choices) == 0:
+            raise RuntimeError("No content returned from LLM call.")
+
+        parsed_response = response_model.model_validate_json(
+            response.choices[0].message.content
+        )
+        end_time = datetime.now()
+        agent_response = AgentResponse(
+            response=parsed_response,
+            id=self.run_id,
+            flow_name=self.flow_name,
+            task_name=task_name,
+            model=model,
+            system=system,
+            input_messages=messages,
+            input_tokens=response.usage.prompt_tokens,
+            output_tokens=response.usage.completion_tokens,
+            start_time=start_time,
+            end_time=end_time,
+        )
+
+        if self.logger is not None:
+            await self.logger(agent_response)
+        else:
+            logging.info(agent_response.model_dump())
+
+        return agent_response.response
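In this version the Agent's logger hook becomes awaitable, so response records can be persisted without blocking the event loop. A minimal sketch of the new contract (the import path goose.agent is inferred from the goose.agent imports elsewhere in this diff; persist_response and its body are illustrative):

    import logging
    from typing import Any

    from goose.agent import Agent, AgentResponse


    async def persist_response(response: AgentResponse[Any]) -> None:
        # Stand-in for an async sink (database insert, HTTP call, queue push).
        logging.info(response.model_dump())


    agent = Agent(flow_name="demo_flow", run_id="run-001", logger=persist_response)

When no logger is supplied, the new code falls back to logging.info(agent_response.model_dump()) rather than treating logging.info as the logger itself, since a synchronous callable can no longer be awaited.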
@@ -0,0 +1,334 @@
+import json
+from contextlib import contextmanager
+from contextvars import ContextVar
+from typing import (
+    Any,
+    Awaitable,
+    Callable,
+    Iterator,
+    NewType,
+    Protocol,
+    Self,
+    overload,
+)
+
+from pydantic import BaseModel, ConfigDict, field_validator
+
+from goose.agent import UserMessage
+from goose.errors import Honk
+
+SerializedFlowState = NewType("SerializedFlowState", str)
+
+
+class Result(BaseModel):
+    model_config = ConfigDict(frozen=True)
+
+
+class GooseResponse[R: Result](BaseModel):
+    result: R
+
+
+class ConversationState[R: Result](BaseModel):
+    messages: list[UserMessage | GooseResponse[R]]
+
+    @field_validator("messages")
+    def alternates_starting_with_result(
+        cls, messages: list[UserMessage | GooseResponse[R]]
+    ) -> list[UserMessage | GooseResponse[R]]:
+        if len(messages) == 0:
+            return messages
+        elif isinstance(messages[0], UserMessage):
+            raise Honk(
+                "User cannot start a conversation on a Task, must begin with a Result"
+            )
+
+        last_message_type: type[UserMessage | GooseResponse[R]] = type(messages[0])
+        for message in messages[1:]:
+            if isinstance(message, last_message_type):
+                raise Honk(
+                    "Conversation must alternate between User and Result messages"
+                )
+            last_message_type = type(message)
+
+        return messages
+
+    @property
+    def awaiting_response(self) -> bool:
+        return len(self.messages) % 2 == 0
+
+
+class IAdapter[ResultT: Result](Protocol):
+    async def __call__(
+        self, *, conversation_state: ConversationState[ResultT]
+    ) -> ResultT: ...
+
+
+class NodeState[ResultT: Result](BaseModel):
+    task_name: str
+    index: int
+    conversation_state: ConversationState[ResultT]
+    last_input_hash: int
+    pinned: bool
+
+    @property
+    def result(self) -> ResultT:
+        last_message = self.conversation_state.messages[-1]
+        if isinstance(last_message, GooseResponse):
+            return last_message.result
+        else:
+            raise Honk("Node awaiting response, has no result")
+
+    def add_result(
+        self,
+        *,
+        result: ResultT,
+        new_input_hash: int | None = None,
+        overwrite: bool = False,
+    ) -> Self:
+        if overwrite:
+            if len(self.conversation_state.messages) == 0:
+                self.conversation_state.messages.append(GooseResponse(result=result))
+            else:
+                self.conversation_state.messages[-1] = GooseResponse(result=result)
+        else:
+            self.conversation_state.messages.append(GooseResponse(result=result))
+        if new_input_hash is not None:
+            self.last_input_hash = new_input_hash
+        return self
+
+    def add_user_message(self, *, message: UserMessage) -> Self:
+        self.conversation_state.messages.append(message)
+        return self
+
+    def pin(self) -> Self:
+        self.pinned = True
+        return self
+
+    def unpin(self) -> Self:
+        self.pinned = False
+        return self
+
+
+class FlowState:
+    def __init__(self) -> None:
+        self._node_states: dict[tuple[str, int], str] = {}
+        self._last_requested_indices: dict[str, int] = {}
+
+    def add(self, node_state: NodeState[Any], /) -> None:
+        key = (node_state.task_name, node_state.index)
+        self._node_states[key] = node_state.model_dump_json()
+
+    def get_next[R: Result](self, *, task: "Task[Any, R]") -> NodeState[R]:
+        if task.name not in self._last_requested_indices:
+            self._last_requested_indices[task.name] = 0
+        else:
+            self._last_requested_indices[task.name] += 1
+
+        return self.get(task=task, index=self._last_requested_indices[task.name])
+
+    def get_all[R: Result](self, *, task: "Task[Any, R]") -> list[NodeState[R]]:
+        matching_nodes: list[NodeState[R]] = []
+        for key, node_state in self._node_states.items():
+            if key[0] == task.name:
+                matching_nodes.append(
+                    NodeState[task.result_type].model_validate_json(node_state)
+                )
+        return matching_nodes
+
+    def get[R: Result](self, *, task: "Task[Any, R]", index: int = 0) -> NodeState[R]:
+        if (
+            existing_node_state := self._node_states.get((task.name, index))
+        ) is not None:
+            return NodeState[task.result_type].model_validate_json(existing_node_state)
+        else:
+            return NodeState[task.result_type](
+                task_name=task.name,
+                index=index or 0,
+                conversation_state=ConversationState[task.result_type](messages=[]),
+                last_input_hash=0,
+                pinned=False,
+            )
+
+    @contextmanager
+    def run(self) -> Iterator[Self]:
+        self._last_requested_indices = {}
+        yield self
+        self._last_requested_indices = {}
+
+    def dump(self) -> SerializedFlowState:
+        return SerializedFlowState(
+            json.dumps(
+                {
+                    ":".join([task_name, str(index)]): value
+                    for (task_name, index), value in self._node_states.items()
+                }
+            )
+        )
+
+    @classmethod
+    def load(cls, state: SerializedFlowState) -> Self:
+        flow_state = cls()
+        raw_node_states = json.loads(state)
+        new_node_states: dict[tuple[str, int], str] = {}
+        for key, node_state in raw_node_states.items():
+            task_name, index = tuple(key.split(":"))
+            new_node_states[(task_name, int(index))] = node_state
+
+        flow_state._node_states = new_node_states
+        return flow_state
+
+
+_current_flow_state: ContextVar[FlowState | None] = ContextVar(
+    "current_flow_state", default=None
+)
+
+
+class Flow[**P]:
+    def __init__(
+        self, fn: Callable[P, Awaitable[None]], /, *, name: str | None = None
+    ) -> None:
+        self._fn = fn
+        self._name = name
+
+    @property
+    def name(self) -> str:
+        return self._name or self._fn.__name__
+
+    @property
+    def state(self) -> FlowState:
+        state = _current_flow_state.get()
+        if state is None:
+            raise Honk("No current flow state")
+        return state
+
+    @contextmanager
+    def run(self, *, state: FlowState | None = None) -> Iterator[FlowState]:
+        if state is None:
+            state = FlowState()
+
+        old_state = _current_flow_state.get()
+        _current_flow_state.set(state)
+        yield state
+        _current_flow_state.set(old_state)
+
+    async def generate(self, *args: P.args, **kwargs: P.kwargs) -> None:
+        with self.state.run():
+            await self._fn(*args, **kwargs)
+
+
+class Task[**P, R: Result]:
+    def __init__(
+        self,
+        generator: Callable[P, Awaitable[R]],
+        /,
+        *,
+        retries: int = 0,
+    ) -> None:
+        self._generator = generator
+        self._adapter: IAdapter[R] | None = None
+        self._retries = retries
+
+    @property
+    def result_type(self) -> type[R]:
+        result_type = self._generator.__annotations__.get("return")
+        if result_type is None:
+            raise Honk(f"Task {self.name} has no return type annotation")
+        return result_type
+
+    @property
+    def name(self) -> str:
+        return self._generator.__name__
+
+    def adapter(self, adapter: IAdapter[R]) -> Self:
+        self._adapter = adapter
+        return self
+
+    async def generate(
+        self, state: NodeState[R], *args: P.args, **kwargs: P.kwargs
+    ) -> R:
+        input_hash = self.__hash_input(*args, **kwargs)
+        if input_hash != state.last_input_hash:
+            result = await self._generator(*args, **kwargs)
+            state.add_result(result=result, new_input_hash=input_hash, overwrite=True)
+            return result
+        else:
+            if not isinstance(state.conversation_state.messages[-1], GooseResponse):
+                raise Honk(
+                    "Conversation must alternate between User and Result messages"
+                )
+            return state.result
+
+    async def adapt(
+        self, *, flow_state: FlowState, user_message: UserMessage, index: int = 0
+    ) -> R:
+        node_state = flow_state.get(task=self, index=index)
+        if self._adapter is None:
+            raise Honk("No adapter provided for Task")
+
+        node_state.add_user_message(message=user_message)
+        result = await self._adapter(conversation_state=node_state.conversation_state)
+        node_state.add_result(result=result)
+        flow_state.add(node_state)
+
+        return result
+
+    async def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R:
+        flow_state = self.__get_current_flow_state()
+        node_state = flow_state.get_next(task=self)
+        result = await self.generate(node_state, *args, **kwargs)
+        flow_state.add(node_state)
+        return result
+
+    def __hash_input(self, *args: P.args, **kwargs: P.kwargs) -> int:
+        try:
+            to_hash = str(tuple(args) + tuple(kwargs.values()))
+            return hash(to_hash)
+        except TypeError:
+            raise Honk(f"Unhashable argument to task {self.name}: {args} {kwargs}")
+
+    def __get_current_flow_state(self) -> FlowState:
+        state = _current_flow_state.get()
+        if state is None:
+            raise Honk("No current flow state")
+        return state
+
+
+@overload
+def task[**P, R: Result](generator: Callable[P, Awaitable[R]], /) -> Task[P, R]: ...
+@overload
+def task[**P, R: Result](
+    *, retries: int = 0
+) -> Callable[[Callable[P, Awaitable[R]]], Task[P, R]]: ...
+def task[**P, R: Result](
+    generator: Callable[P, Awaitable[R]] | None = None,
+    /,
+    *,
+    retries: int = 0,
+) -> Task[P, R] | Callable[[Callable[P, Awaitable[R]]], Task[P, R]]:
+    if generator is None:
+
+        def decorator(fn: Callable[P, Awaitable[R]]) -> Task[P, R]:
+            return Task(fn, retries=retries)
+
+        return decorator
+
+    return Task(generator, retries=retries)
+
+
+@overload
+def flow[**P](fn: Callable[P, Awaitable[None]], /) -> Flow[P]: ...
+@overload
+def flow[**P](
+    *, name: str | None = None
+) -> Callable[[Callable[P, Awaitable[None]]], Flow[P]]: ...
+def flow[**P](
+    fn: Callable[P, Awaitable[None]] | None = None, /, *, name: str | None = None
+) -> Flow[P] | Callable[[Callable[P, Awaitable[None]]], Flow[P]]:
+    if fn is None:

+        def decorator(fn: Callable[P, Awaitable[None]]) -> Flow[P]:
+            return Flow(fn, name=name)
+
+        return decorator
+
+    return Flow(fn, name=name)
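Taken together, this new module supports a decorator-driven workflow: tasks cache on an input hash within a flow run, and FlowState round-trips through dump/load. A minimal end-to-end sketch against the definitions above (Greeting, make_greeting, and greeting_flow are illustrative names, not part of the package):

    import asyncio

    from goose.flow import FlowState, Result, flow, task


    class Greeting(Result):
        text: str


    @task
    async def make_greeting(name: str) -> Greeting:
        return Greeting(text=f"Hello, {name}!")


    @flow
    async def greeting_flow(name: str) -> None:
        await make_greeting(name)


    async def main() -> None:
        with greeting_flow.run() as state:
            await greeting_flow.generate("Alice")
            # Node results are addressable by task and index after the run.
            print(state.get(task=make_greeting).result.text)  # Hello, Alice!
        serialized = state.dump()

        # The flow can be resumed from the serialized state; with the same
        # arguments the input hash matches, so the cached result is reused.
        restored = FlowState.load(serialized)
        with greeting_flow.run(state=restored):
            await greeting_flow.generate("Alice")


    asyncio.run(main())

Note that the cache key comes from Python's built-in hash() of the stringified arguments, so within a single process identical inputs hit the cache, while changed inputs trigger regeneration.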
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "goose-py"
-version = "0.1.4"
+version = "0.2.1"
 description = "A tool for AI workflows based on human-computer collaboration and structured output."
 authors = [
     "Nash Taylor <nash@chelle.ai>",
File without changes
@@ -1,75 +0,0 @@
-import logging
-from datetime import datetime
-from typing import Any, Callable
-
-from litellm import acompletion
-from pydantic import BaseModel
-
-from goose.types import (
-    AgentResponse,
-    AssistantMessage,
-    GeminiModel,
-    SystemMessage,
-    UserMessage,
-)
-
-
-class Agent:
-    def __init__(
-        self,
-        *,
-        flow_name: str,
-        run_id: str,
-        logger: Callable[[AgentResponse[Any]], None] | None = None,
-    ) -> None:
-        self.flow_name = flow_name
-        self.run_id = run_id
-        self.logger = logger or logging.info
-
-    async def __call__[R: BaseModel](
-        self,
-        *,
-        messages: list[UserMessage | AssistantMessage],
-        model: GeminiModel,
-        response_model: type[R],
-        task_name: str,
-        system: SystemMessage | None = None,
-    ) -> R:
-        start_time = datetime.now()
-        rendered_messages = [message.render() for message in messages]
-        if system is not None:
-            rendered_messages.insert(0, system.render())
-
-        response = await acompletion(
-            model=model.value,
-            messages=rendered_messages,
-            response_format={
-                "type": "json_object",
-                "response_schema": response_model.model_json_schema(),
-                "enforce_validation": True,
-            },
-        )
-
-        if len(response.choices) == 0:
-            raise RuntimeError("No content returned from LLM call.")
-
-        parsed_response = response_model.model_validate_json(
-            response.choices[0].message.content
-        )
-        end_time = datetime.now()
-        agent_response = AgentResponse(
-            response=parsed_response,
-            id=self.run_id,
-            flow_name=self.flow_name,
-            task_name=task_name,
-            model=model,
-            system=system,
-            input_messages=messages,
-            input_tokens=response.usage.prompt_tokens,
-            output_tokens=response.usage.completion_tokens,
-            start_time=start_time,
-            end_time=end_time,
-        )
-
-        self.logger(agent_response)
-        return agent_response.response
@@ -1,45 +0,0 @@
-from pydantic import BaseModel
-
-from goose.types import UserMessage
-
-
-class ConversationState[R: BaseModel](BaseModel):
-    user_messages: list[UserMessage]
-    results: list[R]
-
-
-class Conversation[R: BaseModel]:
-    def __init__(
-        self,
-        *,
-        user_messages: list[UserMessage] | None = None,
-        results: list[R] | None = None,
-    ) -> None:
-        self.user_messages = user_messages or []
-        self.results = results or []
-
-    @classmethod
-    def load(cls, *, state: ConversationState[R]) -> "Conversation[R]":
-        return cls(user_messages=state.user_messages, results=state.results)
-
-    @property
-    def current_result(self) -> R:
-        if len(self.results) == 0:
-            raise RuntimeError("No results in conversation")
-
-        return self.results[-1]
-
-    def add_message(self, *, message: UserMessage) -> None:
-        self.user_messages.append(message)
-
-    def add_result(self, *, result: R) -> None:
-        self.results.append(result)
-
-    def replace_last_result(self, *, result: R) -> None:
-        if len(self.results) == 0:
-            self.results.append(result)
-        else:
-            self.results[-1] = result
-
-    def dump(self) -> ConversationState[R]:
-        return ConversationState(user_messages=self.user_messages, results=self.results)
@@ -1,294 +0,0 @@
-import asyncio
-import contextvars
-import inspect
-from collections import defaultdict
-from types import TracebackType
-from typing import Any, Awaitable, Callable, Protocol, Self, overload
-
-from graphlib import TopologicalSorter
-from pydantic import BaseModel
-
-from goose.agent import Agent
-from goose.errors import Honk
-from goose.conversation import Conversation, ConversationState
-from goose.types import AgentResponse, UserMessage
-
-
-class NodeState[R: BaseModel](BaseModel):
-    name: str
-    conversation: ConversationState[R]
-
-    @property
-    def result(self) -> R:
-        return self.conversation.results[-1]
-
-
-class FlowState(BaseModel):
-    nodes: list[NodeState[BaseModel]]
-
-
-class NoResult:
-    pass
-
-
-class IRegenerator[R: BaseModel](Protocol):
-    async def __call__(self, *, result: R, conversation: Conversation[R]) -> R: ...
-
-
-class Task[**P, R: BaseModel]:
-    def __init__(
-        self, generator: Callable[P, Awaitable[R]], /, *, retries: int = 0
-    ) -> None:
-        self.retries = retries
-        self._generator = generator
-        self._regenerator: IRegenerator[R] | None = None
-        self._signature = inspect.signature(generator)
-        self.__validate_fn()
-
-    @property
-    def result_type(self) -> type[R]:
-        return_type = self._generator.__annotations__.get("return")
-        if return_type is None:
-            raise Honk("Task must have a return type annotation")
-
-        return return_type
-
-    @property
-    def name(self) -> str:
-        return self._generator.__name__
-
-    def regenerator(self, regenerator: IRegenerator[R], /) -> Self:
-        self._regenerator = regenerator
-        return self
-
-    async def generate(self, *args: P.args, **kwargs: P.kwargs) -> R:
-        return await self._generator(*args, **kwargs)
-
-    async def regenerate(self, *, result: R, conversation: Conversation[R]) -> R:
-        if self._regenerator is None:
-            raise Honk("Task does not have a regenerator implemented")
-
-        return await self._regenerator(result=result, conversation=conversation)
-
-    def __call__(self, *args: P.args, **kwargs: P.kwargs) -> "Node[R]":
-        arguments = self._signature.bind(*args, **kwargs).arguments
-        return Node(task=self, arguments=arguments, result_type=self.result_type)
-
-    def __validate_fn(self) -> None:
-        if any(
-            param.kind == inspect.Parameter.POSITIONAL_ONLY
-            for param in self._signature.parameters.values()
-        ):
-            raise Honk("Positional-only parameters are not supported in Tasks")
-
-
-class Node[R: BaseModel]:
-    def __init__(
-        self,
-        *,
-        task: Task[Any, R],
-        arguments: dict[str, Any],
-        result_type: type[R],
-        conversation: Conversation[R] | None = None,
-    ) -> None:
-        self.task = task
-        self.arguments = arguments
-        self.result_type = result_type
-        self.conversation = conversation or Conversation[R]()
-        self.name = task.name
-
-        self._result: R | NoResult = NoResult()
-        current_flow = Flow.get_current()
-        if current_flow is None:
-            raise Honk("Cannot create a node without an active flow")
-        self.id = current_flow.add_node(node=self)
-
-    @property
-    def has_result(self) -> bool:
-        return not isinstance(self._result, NoResult)
-
-    @property
-    def result(self) -> R:
-        if isinstance(self._result, NoResult):
-            raise Honk("Cannot access result of a node before it has run")
-        return self._result
-
-    async def generate(self) -> None:
-        self._result = await self.task.generate(**self.arguments)
-        self.conversation.replace_last_result(result=self.result)
-
-    async def regenerate(self, *, message: UserMessage) -> None:
-        self.conversation.add_message(message=message)
-        self._result = await self.task.regenerate(
-            result=self.result, conversation=self.conversation
-        )
-        self.conversation.add_result(result=self.result)
-
-    def dump_state(self) -> NodeState[R]:
-        return NodeState(name=self.name, conversation=self.conversation.dump())
-
-    def load_state(self, *, state: NodeState[R]) -> None:
-        self._result = state.result
-        self.conversation = Conversation[self.result_type].load(
-            state=state.conversation
-        )
-
-    def get_inbound_nodes(self) -> list["Node[BaseModel]"]:
-        def __find_nodes(
-            obj: Any, visited: set[int] | None = None
-        ) -> list["Node[BaseModel]"]:
-            if visited is None:
-                visited = set()
-
-            if isinstance(obj, Node):
-                return [obj]
-            elif isinstance(obj, dict):
-                return [
-                    node
-                    for value in obj.values()
-                    for node in __find_nodes(value, visited)
-                ]
-            elif isinstance(obj, list):
-                return [node for item in obj for node in __find_nodes(item, visited)]
-            elif isinstance(obj, tuple):
-                return [node for item in obj for node in __find_nodes(item, visited)]
-            elif isinstance(obj, set):
-                return [node for item in obj for node in __find_nodes(item, visited)]
-            elif hasattr(obj, "__dict__"):
-                return [
-                    node
-                    for value in obj.__dict__.values()
-                    for node in __find_nodes(value, visited)
-                ]
-            return []
-
-        return __find_nodes(self.arguments)
-
-    def __hash__(self) -> int:
-        return hash(self.id)
-
-
-class Flow:
-    _current: contextvars.ContextVar["Flow | None"] = contextvars.ContextVar(
-        "current_flow", default=None
-    )
-
-    def __init__(
-        self,
-        *,
-        name: str,
-        run_id: str,
-        agent_logger: Callable[[AgentResponse[Any]], None] | None = None,
-    ) -> None:
-        self.name = name
-        self._nodes: list[Node[BaseModel]] = []
-        self._agent = Agent(flow_name=self.name, run_id=run_id, logger=agent_logger)
-
-    @property
-    def agent(self) -> Agent:
-        return self._agent
-
-    def dump_state(self) -> FlowState:
-        return FlowState(nodes=[node.dump_state() for node in self._nodes])
-
-    def load_state(self, *, flow_state: FlowState) -> None:
-        nodes_by_name = {node.name: node for node in self._nodes}
-        for node_state in flow_state.nodes:
-            matching_node = nodes_by_name.get(node_state.name)
-            if matching_node is None:
-                raise Honk(f"Node {node_state.name} from state not found in flow")
-
-            matching_node.load_state(state=node_state)
-
-    async def generate(self) -> None:
-        graph = {node: node.get_inbound_nodes() for node in self._nodes}
-        sorter = TopologicalSorter(graph)
-        sorter.prepare()
-
-        async with asyncio.TaskGroup() as task_group:
-            while sorter.is_active():
-                ready_nodes = list(sorter.get_ready())
-                if ready_nodes:
-                    for node in ready_nodes:
-                        task_group.create_task(node.generate())
-                    sorter.done(*ready_nodes)
-                else:
-                    await asyncio.sleep(0)
-
-    async def regenerate(self, *, target: Node[Any], message: UserMessage) -> None:
-        if not target.has_result:
-            raise Honk("Cannot regenerate a node without a result")
-
-        await target.regenerate(message=message)
-
-        # regenerate all downstream nodes
-        full_graph = {node: node.get_inbound_nodes() for node in self._nodes}
-        reversed_graph: dict[Node[BaseModel], set[Node[BaseModel]]] = defaultdict(set)
-        for node, inbound_nodes in full_graph.items():
-            for inbound_node in inbound_nodes:
-                reversed_graph[inbound_node].add(node)
-
-        subgraph: dict[Node[BaseModel], set[Node[BaseModel]]] = defaultdict(set)
-        queue: list[Node[BaseModel]] = [target]
-
-        while len(queue) > 0:
-            node = queue.pop(0)
-            outbound_nodes = reversed_graph[node]
-            for outbound_node in outbound_nodes:
-                subgraph[outbound_node].add(node)
-                if outbound_node not in subgraph:
-                    queue.append(outbound_node)
-
-        if len(subgraph) > 0:
-            sorter = TopologicalSorter(subgraph)
-            sorter.prepare()
-
-            async with asyncio.TaskGroup() as task_group:
-                while sorter.is_active():
-                    ready_nodes = list(sorter.get_ready())
-                    if len(ready_nodes) > 0:
-                        for node in ready_nodes:
-                            if node != target:
-                                task_group.create_task(node.generate())
-                        sorter.done(*ready_nodes)
-                    else:
-                        await asyncio.sleep(0)
-
-    @classmethod
-    def get_current(cls) -> "Flow | None":
-        return cls._current.get()
-
-    def add_node(self, *, node: Node[Any]) -> str:
-        existing_names = [node.name for node in self._nodes]
-        number = sum(1 for name in existing_names if name == node.name)
-        self._nodes.append(node)
-        node_id = f"{node.name}_{number}"
-        return node_id
-
-    def __enter__(self) -> Self:
-        if self._current.get() is not None:
-            raise Honk("Cannot enter a new flow while another flow is already active")
-        self._current.set(self)
-        return self
-
-    def __exit__(
-        self,
-        exc_type: type[BaseException] | None,
-        exc_value: BaseException | None,
-        traceback: TracebackType | None,
-    ) -> None:
-        self._current.set(None)
-
-
-@overload
-def task[**P, R: BaseModel](fn: Callable[P, Awaitable[R]], /) -> Task[P, R]: ...
-@overload
-def task[**P, R: BaseModel](
-    *, retries: int = 0
-) -> Callable[[Callable[P, Awaitable[R]]], Task[P, R]]: ...
-def task[**P, R: BaseModel](
-    fn: Callable[P, Awaitable[R]] | None = None, /, *, retries: int = 0
-) -> Task[P, R] | Callable[[Callable[P, Awaitable[R]]], Task[P, R]]:
-    if fn is None:
-        return lambda fn: Task(fn, retries=retries)
-    return Task(fn, retries=retries)
File without changes
File without changes