goose-py 0.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,31 @@
+ Metadata-Version: 2.1
+ Name: goose-py
+ Version: 0.1.0
+ Summary: A tool for AI workflows based on human-computer collaboration and structured output.
+ Home-page: https://github.com/chelle-ai/goose
+ Keywords: ai,yaml,configuration,llm
+ Author: Nash Taylor
+ Author-email: nash@chelle.ai
+ Requires-Python: >=3.12,<4.0
+ Classifier: Development Status :: 4 - Beta
+ Classifier: Intended Audience :: Developers
+ Classifier: Operating System :: OS Independent
+ Classifier: Programming Language :: Python :: 3
+ Classifier: Programming Language :: Python :: 3.12
+ Classifier: Programming Language :: Python :: 3.13
+ Provides-Extra: test
+ Requires-Dist: ipykernel ; extra == "test"
+ Requires-Dist: jsonpath-ng (>=1.7.0,<2.0.0)
+ Requires-Dist: litellm (>=1.56.5,<2.0.0)
+ Requires-Dist: pydantic (>=2.8.2,<3.0.0)
+ Requires-Dist: pytest (<8) ; extra == "test"
+ Requires-Dist: pytest-asyncio ; extra == "test"
+ Requires-Dist: pytest-mock ; extra == "test"
+ Project-URL: Documentation, https://github.com/chelle-ai/goose
+ Project-URL: Repository, https://github.com/chelle-ai/goose
+ Description-Content-Type: text/markdown
+
+ # Vers
+
+ Docs to come.
+
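Per the metadata above, the runtime dependencies are jsonpath-ng, litellm, and pydantic, while the `test` extra adds ipykernel, pytest (<8), pytest-asyncio, and pytest-mock; assuming the published wheel matches this metadata, `pip install "goose-py[test]"` should pull in both sets on Python 3.12+.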
@@ -0,0 +1,3 @@
+ # Vers
+
+ Docs to come.
File without changes
@@ -0,0 +1,74 @@
+ import logging
+ import uuid
+ from datetime import datetime
+ from typing import Any, Callable
+
+ from litellm import acompletion
+ from pydantic import BaseModel
+
+ from goose.types import (
+     AgentResponse,
+     AssistantMessage,
+     GeminiModel,
+     SystemMessage,
+     UserMessage,
+ )
+
+
+ class Agent:
+     def __init__(
+         self,
+         *,
+         flow_name: str,
+         logger: Callable[[AgentResponse[Any]], None] | None = None,
+     ) -> None:
+         self.flow_name = flow_name
+         self.logger = logger or logging.info
+
+     async def __call__[R: BaseModel](
+         self,
+         *,
+         messages: list[UserMessage | AssistantMessage],
+         model: GeminiModel,
+         response_model: type[R],
+         task_name: str,
+         system: SystemMessage | None = None,
+     ) -> R:
+         start_time = datetime.now()
+         rendered_messages = [message.render() for message in messages]
+         if system is not None:
+             rendered_messages.insert(0, system.render())
+
+         response = await acompletion(
+             model=model.value,
+             messages=rendered_messages,
+             response_format={
+                 "type": "json_object",
+                 "response_schema": response_model.model_json_schema(),
+                 "enforce_validation": True,
+             },
+         )
+
+         if len(response.choices) == 0:
+             raise RuntimeError("No content returned from LLM call.")
+
+         parsed_response = response_model.model_validate_json(
+             response.choices[0].message.content
+         )
+         end_time = datetime.now()
+         agent_response = AgentResponse(
+             response=parsed_response,
+             id=str(uuid.uuid4()),
+             flow_name=self.flow_name,
+             task_name=task_name,
+             model=model,
+             system=system,
+             input_messages=messages,
+             input_tokens=response.usage.prompt_tokens,
+             output_tokens=response.usage.completion_tokens,
+             start_time=start_time,
+             end_time=end_time,
+         )
+
+         self.logger(agent_response)
+         return agent_response.response
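A minimal sketch of driving this `Agent` directly, assuming litellm is configured with Gemini credentials; the flow name, task name, and the `CapitalAnswer` response model are illustrative, not part of the package:

```python
import asyncio

from pydantic import BaseModel

from goose.agent import Agent
from goose.types import GeminiModel, TextMessagePart, UserMessage


class CapitalAnswer(BaseModel):  # illustrative structured-output model
    capital: str


async def main() -> None:
    agent = Agent(flow_name="demo")  # falls back to logging.info for responses
    answer = await agent(
        messages=[UserMessage(parts=[TextMessagePart(text="Capital of France?")])],
        model=GeminiModel.FLASH,
        response_model=CapitalAnswer,
        task_name="capital",
    )
    print(answer.capital)  # already validated against CapitalAnswer's schema


asyncio.run(main())
```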
@@ -0,0 +1,45 @@
+ from pydantic import BaseModel
+
+ from goose.types import UserMessage
+
+
+ class ConversationState[R: BaseModel](BaseModel):
+     user_messages: list[UserMessage]
+     results: list[R]
+
+
+ class Conversation[R: BaseModel]:
+     def __init__(
+         self,
+         *,
+         user_messages: list[UserMessage] | None = None,
+         results: list[R] | None = None,
+     ) -> None:
+         self.user_messages = user_messages or []
+         self.results = results or []
+
+     @classmethod
+     def load(cls, *, state: ConversationState[R]) -> "Conversation[R]":
+         return cls(user_messages=state.user_messages, results=state.results)
+
+     @property
+     def current_result(self) -> R:
+         if len(self.results) == 0:
+             raise RuntimeError("No results in conversation")
+
+         return self.results[-1]
+
+     def add_message(self, *, message: UserMessage) -> None:
+         self.user_messages.append(message)
+
+     def add_result(self, *, result: R) -> None:
+         self.results.append(result)
+
+     def replace_last_result(self, *, result: R) -> None:
+         if len(self.results) == 0:
+             self.results.append(result)
+         else:
+             self.results[-1] = result
+
+     def dump(self) -> ConversationState[R]:
+         return ConversationState(user_messages=self.user_messages, results=self.results)
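Because `ConversationState` is itself a pydantic model, a conversation survives a JSON round trip; a small sketch, with an illustrative `Answer` result model:

```python
from pydantic import BaseModel

from goose.conversation import Conversation, ConversationState
from goose.types import TextMessagePart, UserMessage


class Answer(BaseModel):  # illustrative result model
    text: str


conversation = Conversation[Answer]()
conversation.add_message(message=UserMessage(parts=[TextMessagePart(text="hi")]))
conversation.add_result(result=Answer(text="hello"))

# Serialize through the pydantic state model, then restore.
raw = conversation.dump().model_dump_json()
state = ConversationState[Answer].model_validate_json(raw)
restored = Conversation.load(state=state)
assert restored.current_result.text == "hello"
```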
@@ -0,0 +1,294 @@
+ import asyncio
+ import contextvars
+ import inspect
+ from collections import defaultdict
+ from types import TracebackType
+ from typing import Any, Awaitable, Callable, Protocol, Self, overload
+
+ from graphlib import TopologicalSorter
+ from pydantic import BaseModel
+
+ from goose.agent import Agent
+ from goose.conversation import Conversation, ConversationState
+ from goose.regenerator import default_regenerator
+ from goose.types import AgentResponse, UserMessage
+
+
+ class NodeState[R: BaseModel](BaseModel):
+     name: str
+     conversation: ConversationState[R]
+
+     @property
+     def result(self) -> R:
+         return self.conversation.results[-1]
+
+
+ class FlowState(BaseModel):
+     nodes: list[NodeState[BaseModel]]
+
+
+ class NoResult:
+     pass
+
+
+ class IRegenerator[R: BaseModel](Protocol):
+     async def __call__(self, *, result: R, conversation: Conversation[R]) -> R: ...
+
+
+ class Task[**P, R: BaseModel]:
+     def __init__(
+         self, generator: Callable[P, Awaitable[R]], /, *, retries: int = 0
+     ) -> None:
+         self.retries = retries
+         self._generator = generator
+         self._regenerator: IRegenerator[R] = default_regenerator
+         self._signature = inspect.signature(generator)
+         self.__validate_fn()
+
+     @property
+     def result_type(self) -> type[R]:
+         return_type = self._generator.__annotations__.get("return")
+         if return_type is None:
+             raise TypeError("Task must have a return type annotation")
+
+         return return_type
+
+     @property
+     def name(self) -> str:
+         return self._generator.__name__
+
+     def regenerator(self, regenerator: IRegenerator[R], /) -> Self:
+         self._regenerator = regenerator
+         return self
+
+     async def generate(self, *args: P.args, **kwargs: P.kwargs) -> R:
+         return await self._generator(*args, **kwargs)
+
+     async def regenerate(self, *, result: R, conversation: Conversation[R]) -> R:
+         return await self._regenerator(result=result, conversation=conversation)
+
+     def __call__(self, *args: P.args, **kwargs: P.kwargs) -> "Node[R]":
+         arguments = self._signature.bind(*args, **kwargs).arguments
+         return Node(task=self, arguments=arguments, result_type=self.result_type)
+
+     def __validate_fn(self) -> None:
+         if any(
+             param.kind == inspect.Parameter.POSITIONAL_ONLY
+             for param in self._signature.parameters.values()
+         ):
+             raise ValueError("Positional-only parameters are not supported in Tasks")
+
+
+ class Node[R: BaseModel]:
+     def __init__(
+         self,
+         *,
+         task: Task[Any, R],
+         arguments: dict[str, Any],
+         result_type: type[R],
+         conversation: Conversation[R] | None = None,
+     ) -> None:
+         self.task = task
+         self.arguments = arguments
+         self.result_type = result_type
+         self.conversation = conversation or Conversation[R]()
+         self.name = task.name
+
+         self._result: R | NoResult = NoResult()
+         current_flow = Flow.get_current()
+         if current_flow is None:
+             raise RuntimeError("Cannot create a node without an active flow")
+         self.id = current_flow.add_node(node=self)
+
+     @property
+     def has_result(self) -> bool:
+         return not isinstance(self._result, NoResult)
+
+     @property
+     def result(self) -> R:
+         if isinstance(self._result, NoResult):
+             raise RuntimeError("Cannot access result of a node before it has run")
+         return self._result
+
+     async def generate(self) -> None:
+         self._result = await self.task.generate(**self.arguments)
+         self.conversation.replace_last_result(result=self.result)
+
+     async def regenerate(self, *, message: UserMessage) -> None:
+         self.conversation.add_message(message=message)
+         self._result = await self.task.regenerate(
+             result=self.result, conversation=self.conversation
+         )
+         self.conversation.add_result(result=self.result)
+
+     def dump_state(self) -> NodeState[R]:
+         return NodeState(name=self.name, conversation=self.conversation.dump())
+
+     def load_state(self, *, state: NodeState[R]) -> None:
+         self._result = state.result
+         self.conversation = Conversation[self.result_type].load(
+             state=state.conversation
+         )
+
+     def get_inbound_nodes(self) -> list["Node[BaseModel]"]:
+         def __find_nodes(
+             obj: Any, visited: set[int] | None = None
+         ) -> list["Node[BaseModel]"]:
+             if visited is None:
+                 visited = set()
+             # Track visited object ids so cyclic argument structures cannot
+             # recurse forever.
+             if id(obj) in visited:
+                 return []
+             visited.add(id(obj))
+
+             if isinstance(obj, Node):
+                 return [obj]
+             elif isinstance(obj, dict):
+                 return [
+                     node
+                     for value in obj.values()
+                     for node in __find_nodes(value, visited)
+                 ]
+             elif isinstance(obj, (list, tuple, set)):
+                 return [
+                     node for item in obj for node in __find_nodes(item, visited)
+                 ]
+             elif hasattr(obj, "__dict__"):
+                 return [
+                     node
+                     for value in obj.__dict__.values()
+                     for node in __find_nodes(value, visited)
+                 ]
+             return []
+
+         return __find_nodes(self.arguments)
+
+     def __hash__(self) -> int:
+         return hash(self.id)
+
+
+ class Flow:
+     _current: contextvars.ContextVar["Flow | None"] = contextvars.ContextVar(
+         "current_flow", default=None
+     )
+
+     def __init__(
+         self,
+         *,
+         name: str,
+         agent_logger: Callable[[AgentResponse[Any]], None] | None = None,
+     ) -> None:
+         self.name = name
+         self._nodes: list[Node[BaseModel]] = []
+         self._agent = Agent(flow_name=self.name, logger=agent_logger)
+
+     @property
+     def agent(self) -> Agent:
+         return self._agent
+
+     def dump_state(self) -> FlowState:
+         return FlowState(nodes=[node.dump_state() for node in self._nodes])
+
+     def load_state(self, *, flow_state: FlowState) -> None:
+         nodes_by_name = {node.name: node for node in self._nodes}
+         for node_state in flow_state.nodes:
+             matching_node = nodes_by_name.get(node_state.name)
+             if matching_node is None:
+                 raise RuntimeError(
+                     f"Node {node_state.name} from state not found in flow"
+                 )
+
+             matching_node.load_state(state=node_state)
+
+     async def generate(self) -> None:
+         graph = {node: node.get_inbound_nodes() for node in self._nodes}
+         sorter = TopologicalSorter(graph)
+         sorter.prepare()
+
+         # Await each batch of ready nodes before marking it done, so a
+         # downstream node never reads an upstream result that is not set yet.
+         while sorter.is_active():
+             ready_nodes = list(sorter.get_ready())
+             async with asyncio.TaskGroup() as task_group:
+                 for node in ready_nodes:
+                     task_group.create_task(node.generate())
+             sorter.done(*ready_nodes)
+
+     async def regenerate(self, *, target: Node[Any], message: UserMessage) -> None:
+         if not target.has_result:
+             raise RuntimeError("Cannot regenerate a node without a result")
+
+         await target.regenerate(message=message)
+
+         # Regenerate all downstream nodes: invert the dependency graph, then
+         # collect the subgraph reachable from the target.
+         full_graph = {node: node.get_inbound_nodes() for node in self._nodes}
+         reversed_graph: dict[Node[BaseModel], set[Node[BaseModel]]] = defaultdict(set)
+         for node, inbound_nodes in full_graph.items():
+             for inbound_node in inbound_nodes:
+                 reversed_graph[inbound_node].add(node)
+
+         subgraph: dict[Node[BaseModel], set[Node[BaseModel]]] = defaultdict(set)
+         queue: list[Node[BaseModel]] = [target]
+
+         while len(queue) > 0:
+             node = queue.pop(0)
+             for outbound_node in reversed_graph[node]:
+                 # Enqueue each dependent the first time it is seen, before the
+                 # defaultdict lookup below inserts it into the subgraph.
+                 if outbound_node not in subgraph:
+                     queue.append(outbound_node)
+                 subgraph[outbound_node].add(node)
+
+         if len(subgraph) > 0:
+             sorter = TopologicalSorter(subgraph)
+             sorter.prepare()
+
+             # As in generate(), finish each ready batch before releasing its
+             # dependents; the freshly regenerated target itself is skipped.
+             while sorter.is_active():
+                 ready_nodes = list(sorter.get_ready())
+                 async with asyncio.TaskGroup() as task_group:
+                     for node in ready_nodes:
+                         if node != target:
+                             task_group.create_task(node.generate())
+                 sorter.done(*ready_nodes)
+
+     @classmethod
+     def get_current(cls) -> "Flow | None":
+         return cls._current.get()
+
+     def add_node(self, *, node: Node[Any]) -> str:
+         existing_names = [existing.name for existing in self._nodes]
+         number = sum(1 for name in existing_names if name == node.name)
+         self._nodes.append(node)
+         node_id = f"{node.name}_{number}"
+         return node_id
+
+     def __enter__(self) -> Self:
+         if self._current.get() is not None:
+             raise RuntimeError(
+                 "Cannot enter a new flow while another flow is already active"
+             )
+         self._current.set(self)
+         return self
+
+     def __exit__(
+         self,
+         exc_type: type[BaseException] | None,
+         exc_value: BaseException | None,
+         traceback: TracebackType | None,
+     ) -> None:
+         self._current.set(None)
+
+
+ @overload
+ def task[**P, R: BaseModel](fn: Callable[P, Awaitable[R]], /) -> Task[P, R]: ...
+ @overload
+ def task[**P, R: BaseModel](
+     *, retries: int = 0
+ ) -> Callable[[Callable[P, Awaitable[R]]], Task[P, R]]: ...
+ def task[**P, R: BaseModel](
+     fn: Callable[P, Awaitable[R]] | None = None, /, *, retries: int = 0
+ ) -> Task[P, R] | Callable[[Callable[P, Awaitable[R]]], Task[P, R]]:
+     if fn is None:
+         return lambda fn: Task(fn, retries=retries)
+     return Task(fn, retries=retries)
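A sketch of how these pieces might compose, with two illustrative tasks; passing one task's `Node` into another is exactly what `get_inbound_nodes` walks to build the dependency graph:

```python
import asyncio

from pydantic import BaseModel

from goose.flow import Flow, Node, task


class Draft(BaseModel):  # illustrative result models
    text: str


class Summary(BaseModel):
    text: str


@task
async def write_draft(*, topic: str) -> Draft:
    return Draft(text=f"A draft about {topic}")


@task
async def summarize(*, draft: Node[Draft]) -> Summary:
    # Upstream results are read off the node once the flow has run it.
    return Summary(text=draft.result.text[:20])


async def main() -> None:
    with Flow(name="demo") as flow:
        draft_node = write_draft(topic="geese")
        summary_node = summarize(draft=draft_node)  # registers the dependent node
        await flow.generate()
    print(summary_node.result.text)


asyncio.run(main())
```

From here, `flow.regenerate(target=draft_node, message=...)` would rerun the target through its regenerator and then regenerate the downstream subgraph.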
File without changes
@@ -0,0 +1,5 @@
+ from goose.conversation import Conversation
+
+
+ async def default_regenerator[R](*, result: R, conversation: Conversation[R]) -> R:
+     return result
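The default regenerator is a no-op that returns the previous result unchanged. A custom one can be attached fluently via `Task.regenerator`; a sketch with an illustrative revision hook (a real one would presumably feed `conversation.user_messages` back through the agent):

```python
from pydantic import BaseModel

from goose.conversation import Conversation
from goose.flow import task


class Draft(BaseModel):  # illustrative result model, as above
    text: str


@task
async def write_draft(*, topic: str) -> Draft:
    return Draft(text=f"A draft about {topic}")


async def revise(*, result: Draft, conversation: Conversation[Draft]) -> Draft:
    # Marker edit only; shows the IRegenerator hook firing on regenerate().
    return Draft(text=result.text + " (revised)")


write_draft.regenerator(revise)  # returns the Task itself, so it chains
```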
@@ -0,0 +1,127 @@
+ import base64
+ from datetime import datetime
+ from enum import StrEnum
+ from typing import ClassVar, Literal, NotRequired, TypedDict
+
+ from pydantic import BaseModel, computed_field
+
+
+ class GeminiModel(StrEnum):
+     EXP = "gemini/gemini-exp-1121"
+     PRO = "gemini/gemini-1.5-pro"
+     FLASH = "gemini/gemini-1.5-flash"
+     FLASH_8B = "gemini/gemini-1.5-flash-8b"
+
+
+ class UserMediaContentType(StrEnum):
+     JPEG = "image/jpeg"
+     PNG = "image/png"
+     WEBP = "image/webp"
+     MP3 = "audio/mpeg"
+     WAV = "audio/wav"
+
+
+ class LLMTextMessagePart(TypedDict):
+     type: Literal["text"]
+     text: str
+
+
+ class LLMMediaMessagePart(TypedDict):
+     type: Literal["image_url"]
+     image_url: str
+
+
+ class CacheControl(TypedDict):
+     type: Literal["ephemeral"]
+
+
+ class LLMMessage(TypedDict):
+     role: Literal["user", "assistant", "system"]
+     content: list[LLMTextMessagePart | LLMMediaMessagePart]
+     cache_control: NotRequired[CacheControl]
+
+
+ class TextMessagePart(BaseModel):
+     text: str
+
+     def render(self) -> LLMTextMessagePart:
+         return {"type": "text", "text": self.text}
+
+
+ class MediaMessagePart(BaseModel):
+     content_type: UserMediaContentType
+     content: bytes
+
+     def render(self) -> LLMMediaMessagePart:
+         return {
+             "type": "image_url",
+             "image_url": f"data:{self.content_type};base64,{base64.b64encode(self.content).decode()}",
+         }
+
+
+ class UserMessage(BaseModel):
+     parts: list[TextMessagePart | MediaMessagePart]
+
+     def render(self) -> LLMMessage:
+         content: LLMMessage = {
+             "role": "user",
+             "content": [part.render() for part in self.parts],
+         }
+         if any(isinstance(part, MediaMessagePart) for part in self.parts):
+             content["cache_control"] = {"type": "ephemeral"}
+         return content
+
+
+ class AssistantMessage(BaseModel):
+     text: str
+
+     def render(self) -> LLMMessage:
+         return {"role": "assistant", "content": [{"type": "text", "text": self.text}]}
+
+
+ class SystemMessage(BaseModel):
+     text: str
+
+     def render(self) -> LLMMessage:
+         return {"role": "system", "content": [{"type": "text", "text": self.text}]}
+
+
+ class AgentResponse[R: BaseModel](BaseModel):
+     INPUT_CENTS_PER_MILLION_TOKENS: ClassVar[dict[GeminiModel, float]] = {
+         GeminiModel.FLASH_8B: 30,
+         GeminiModel.FLASH: 15,
+         GeminiModel.PRO: 500,
+         GeminiModel.EXP: 0,
+     }
+     OUTPUT_CENTS_PER_MILLION_TOKENS: ClassVar[dict[GeminiModel, float]] = {
+         GeminiModel.FLASH_8B: 30,
+         GeminiModel.FLASH: 15,
+         GeminiModel.PRO: 500,
+         GeminiModel.EXP: 0,
+     }
+
+     response: R
+     id: str
+     flow_name: str
+     task_name: str
+     model: GeminiModel
+     system: SystemMessage | None = None
+     input_messages: list[UserMessage | AssistantMessage]
+     input_tokens: int
+     output_tokens: int
+     start_time: datetime
+     end_time: datetime
+
+     @computed_field
+     @property
+     def duration_ms(self) -> int:
+         return int((self.end_time - self.start_time).total_seconds() * 1000)
+
+     @computed_field
+     @property
+     def total_cost(self) -> float:
+         # Rates are cents per million tokens, so scale token counts down
+         # accordingly to get a cost in cents.
+         input_cost = (
+             self.INPUT_CENTS_PER_MILLION_TOKENS[self.model] * self.input_tokens
+         ) / 1_000_000
+         output_cost = (
+             self.OUTPUT_CENTS_PER_MILLION_TOKENS[self.model] * self.output_tokens
+         ) / 1_000_000
+         return input_cost + output_cost
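A short sketch of how these message types render into litellm's wire format; the PNG bytes are a placeholder, not a real image:

```python
from goose.types import (
    MediaMessagePart,
    TextMessagePart,
    UserMediaContentType,
    UserMessage,
)

message = UserMessage(
    parts=[
        TextMessagePart(text="Describe this image."),
        MediaMessagePart(
            content_type=UserMediaContentType.PNG,
            content=b"\x89PNG\r\n",  # placeholder bytes
        ),
    ]
)

rendered = message.render()
assert rendered["role"] == "user"
# A media part marks the whole message with ephemeral cache_control.
assert rendered["cache_control"] == {"type": "ephemeral"}
```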
@@ -0,0 +1,62 @@
+ [tool.poetry]
+ name = "goose-py"
+ version = "0.1.0"
+ description = "A tool for AI workflows based on human-computer collaboration and structured output."
+ authors = [
+     "Nash Taylor <nash@chelle.ai>",
+     "Joshua Cook <joshua@chelle.ai>",
+     "Michael Sankur <michael@chelle.ai>"
+ ]
+ readme = "README.md"
+ homepage = "https://github.com/chelle-ai/goose"
+ repository = "https://github.com/chelle-ai/goose"
+ documentation = "https://github.com/chelle-ai/goose"
+ keywords = ["ai", "yaml", "configuration", "llm"]
+ classifiers = [
+     "Development Status :: 4 - Beta",
+     "Intended Audience :: Developers",
+     "Operating System :: OS Independent",
+     "Programming Language :: Python :: 3",
+     "Programming Language :: Python :: 3.12",
+ ]
+ packages = [
+     { include = "goose" }
+ ]
+
+
+ [tool.poetry.dependencies]
+ python = "^3.12"
+
+ jsonpath-ng = "^1.7.0"
+ litellm = "^1.56.5"
+ pydantic = "^2.8.2"
+ ipykernel = { version = "*", optional = true }
+ pytest = { version = "<8", optional = true }
+ pytest-mock = { version = "*", optional = true }
+ pytest-asyncio = { version = "*", optional = true }
+
+
+ [tool.poetry.extras]
+ test = ["ipykernel", "pytest", "pytest-mock", "pytest-asyncio"]
+
+
+ [build-system]
+ requires = ["poetry-core>=1.0.0", "poetry-dynamic-versioning>=1.0.0"]
+ build-backend = "poetry_dynamic_versioning.backend"
+
+ [tool.pyright]
+ pythonVersion = "3.12"
+ typeCheckingMode = "strict"
+ reportMissingModuleSource = false
+ useLibraryCodeForTypes = false
+ reportImportCycles = true
+ reportUnknownMemberType = false
+ reportUnknownVariableType = false
+ stubPath = ".stubs"
+
+ [tool.pytest.ini_options]
+ filterwarnings = [
+     "ignore::DeprecationWarning",
+     "ignore::SyntaxWarning",
+     "ignore::UserWarning",
+ ]