goose-py 0.2.2__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
goose/__init__.py CHANGED
@@ -1,3 +0,0 @@
1
- from goose.flow import ConversationState, FlowState, Result, flow, task
2
-
3
- __all__ = ["ConversationState", "FlowState", "Result", "flow", "task"]
goose/agent.py CHANGED
@@ -88,10 +88,13 @@ class AssistantMessage(BaseModel):
88
88
 
89
89
 
90
90
  class SystemMessage(BaseModel):
91
- text: str
91
+ parts: list[TextMessagePart | MediaMessagePart]
92
92
 
93
93
  def render(self) -> LLMMessage:
94
- return {"role": "system", "content": [{"type": "text", "text": self.text}]}
94
+ return {
95
+ "role": "system",
96
+ "content": [part.render() for part in self.parts],
97
+ }
95
98
 
96
99
 
97
100
  class AgentResponse[R: BaseModel](BaseModel):
@@ -109,7 +112,7 @@ class AgentResponse[R: BaseModel](BaseModel):
109
112
  }
110
113
 
111
114
  response: R
112
- id: str
115
+ run_id: str
113
116
  flow_name: str
114
117
  task_name: str
115
118
  model: GeminiModel
@@ -186,7 +189,7 @@ class Agent:
186
189
  end_time = datetime.now()
187
190
  agent_response = AgentResponse(
188
191
  response=parsed_response,
189
- id=self.run_id,
192
+ run_id=self.run_id,
190
193
  flow_name=self.flow_name,
191
194
  task_name=task_name,
192
195
  model=model,
goose/flow.py CHANGED
@@ -14,10 +14,10 @@ from typing import (
14
14
 
15
15
  from pydantic import BaseModel, ConfigDict, field_validator
16
16
 
17
- from goose.agent import UserMessage
17
+ from goose.agent import Agent, AssistantMessage, LLMMessage, SystemMessage, UserMessage
18
18
  from goose.errors import Honk
19
19
 
20
- SerializedFlowState = NewType("SerializedFlowState", str)
20
+ SerializedFlowRun = NewType("SerializedFlowRun", str)
21
21
 
22
22
 
23
23
  class Result(BaseModel):
@@ -28,8 +28,9 @@ class GooseResponse[R: Result](BaseModel):
28
28
  result: R
29
29
 
30
30
 
31
- class ConversationState[R: Result](BaseModel):
31
+ class Conversation[R: Result](BaseModel):
32
32
  messages: list[UserMessage | GooseResponse[R]]
33
+ context: SystemMessage | None = None
33
34
 
34
35
  @field_validator("messages")
35
36
  def alternates_starting_with_result(
@@ -56,28 +57,44 @@ class ConversationState[R: Result](BaseModel):
56
57
  def awaiting_response(self) -> bool:
57
58
  return len(self.messages) % 2 == 0
58
59
 
60
+ def render(self) -> list[LLMMessage]:
61
+ messages: list[LLMMessage] = []
62
+ if self.context is not None:
63
+ messages.append(self.context.render())
64
+
65
+ for message in self.messages:
66
+ if isinstance(message, UserMessage):
67
+ messages.append(message.render())
68
+ else:
69
+ messages.append(
70
+ AssistantMessage(text=message.result.model_dump_json()).render()
71
+ )
72
+
73
+ return messages
74
+
59
75
 
60
76
  class IAdapter[ResultT: Result](Protocol):
61
- async def __call__(
62
- self, *, conversation_state: ConversationState[ResultT]
63
- ) -> ResultT: ...
77
+ async def __call__(self, *, conversation: Conversation[ResultT]) -> ResultT: ...
64
78
 
65
79
 
66
80
  class NodeState[ResultT: Result](BaseModel):
67
81
  task_name: str
68
82
  index: int
69
- conversation_state: ConversationState[ResultT]
83
+ conversation: Conversation[ResultT]
70
84
  last_input_hash: int
71
- pinned: bool
72
85
 
73
86
  @property
74
87
  def result(self) -> ResultT:
75
- last_message = self.conversation_state.messages[-1]
88
+ last_message = self.conversation.messages[-1]
76
89
  if isinstance(last_message, GooseResponse):
77
90
  return last_message.result
78
91
  else:
79
92
  raise Honk("Node awaiting response, has no result")
80
93
 
94
+ def set_context(self, *, context: SystemMessage) -> Self:
95
+ self.conversation.context = context
96
+ return self
97
+
81
98
  def add_result(
82
99
  self,
83
100
  *,
@@ -86,33 +103,42 @@ class NodeState[ResultT: Result](BaseModel):
86
103
  overwrite: bool = False,
87
104
  ) -> Self:
88
105
  if overwrite:
89
- if len(self.conversation_state.messages) == 0:
90
- self.conversation_state.messages.append(GooseResponse(result=result))
106
+ if len(self.conversation.messages) == 0:
107
+ self.conversation.messages.append(GooseResponse(result=result))
91
108
  else:
92
- self.conversation_state.messages[-1] = GooseResponse(result=result)
109
+ self.conversation.messages[-1] = GooseResponse(result=result)
93
110
  else:
94
- self.conversation_state.messages.append(GooseResponse(result=result))
111
+ self.conversation.messages.append(GooseResponse(result=result))
95
112
  if new_input_hash is not None:
96
113
  self.last_input_hash = new_input_hash
97
114
  return self
98
115
 
99
116
  def add_user_message(self, *, message: UserMessage) -> Self:
100
- self.conversation_state.messages.append(message)
101
- return self
102
-
103
- def pin(self) -> Self:
104
- self.pinned = True
105
- return self
106
-
107
- def unpin(self) -> Self:
108
- self.pinned = False
117
+ self.conversation.messages.append(message)
109
118
  return self
110
119
 
111
120
 
112
- class FlowState:
121
+ class FlowRun:
113
122
  def __init__(self) -> None:
114
123
  self._node_states: dict[tuple[str, int], str] = {}
115
124
  self._last_requested_indices: dict[str, int] = {}
125
+ self._flow_name = ""
126
+ self._id = ""
127
+ self._agent: Agent | None = None
128
+
129
+ @property
130
+ def flow_name(self) -> str:
131
+ return self._flow_name
132
+
133
+ @property
134
+ def id(self) -> str:
135
+ return self._id
136
+
137
+ @property
138
+ def agent(self) -> Agent:
139
+ if self._agent is None:
140
+ raise Honk("Agent is only accessible once a run is started")
141
+ return self._agent
116
142
 
117
143
  def add(self, node_state: NodeState[Any], /) -> None:
118
144
  key = (node_state.task_name, node_state.index)
@@ -144,19 +170,24 @@ class FlowState:
144
170
  return NodeState[task.result_type](
145
171
  task_name=task.name,
146
172
  index=index or 0,
147
- conversation_state=ConversationState[task.result_type](messages=[]),
173
+ conversation=Conversation[task.result_type](messages=[]),
148
174
  last_input_hash=0,
149
- pinned=False,
150
175
  )
151
176
 
152
- @contextmanager
153
- def run(self) -> Iterator[Self]:
177
+ def start(self, *, flow_name: str, run_id: str) -> None:
154
178
  self._last_requested_indices = {}
155
- yield self
179
+ self._flow_name = flow_name
180
+ self._id = run_id
181
+ self._agent = Agent(flow_name=self.flow_name, run_id=self.id)
182
+
183
+ def end(self) -> None:
156
184
  self._last_requested_indices = {}
185
+ self._flow_name = ""
186
+ self._id = ""
187
+ self._agent = None
157
188
 
158
- def dump(self) -> SerializedFlowState:
159
- return SerializedFlowState(
189
+ def dump(self) -> SerializedFlowRun:
190
+ return SerializedFlowRun(
160
191
  json.dumps(
161
192
  {
162
193
  ":".join([task_name, str(index)]): value
@@ -166,20 +197,20 @@ class FlowState:
166
197
  )
167
198
 
168
199
  @classmethod
169
- def load(cls, state: SerializedFlowState) -> Self:
170
- flow_state = cls()
171
- raw_node_states = json.loads(state)
200
+ def load(cls, run: SerializedFlowRun, /) -> Self:
201
+ flow_run = cls()
202
+ raw_node_states = json.loads(run)
172
203
  new_node_states: dict[tuple[str, int], str] = {}
173
204
  for key, node_state in raw_node_states.items():
174
205
  task_name, index = tuple(key.split(":"))
175
206
  new_node_states[(task_name, int(index))] = node_state
176
207
 
177
- flow_state._node_states = new_node_states
178
- return flow_state
208
+ flow_run._node_states = new_node_states
209
+ return flow_run
179
210
 
180
211
 
181
- _current_flow_state: ContextVar[FlowState | None] = ContextVar(
182
- "current_flow_state", default=None
212
+ _current_flow_run: ContextVar[FlowRun | None] = ContextVar(
213
+ "current_flow_run", default=None
183
214
  )
184
215
 
185
216
 
@@ -195,25 +226,32 @@ class Flow[**P]:
195
226
  return self._name or self._fn.__name__
196
227
 
197
228
  @property
198
- def state(self) -> FlowState:
199
- state = _current_flow_state.get()
200
- if state is None:
201
- raise Honk("No current flow state")
202
- return state
229
+ def current_run(self) -> FlowRun:
230
+ run = _current_flow_run.get()
231
+ if run is None:
232
+ raise Honk("No current flow run")
233
+ return run
203
234
 
204
235
  @contextmanager
205
- def run(self, *, state: FlowState | None = None) -> Iterator[FlowState]:
206
- if state is None:
207
- state = FlowState()
236
+ def start_run(
237
+ self, *, run_id: str, preload: FlowRun | None = None
238
+ ) -> Iterator[FlowRun]:
239
+ if preload is None:
240
+ run = FlowRun()
241
+ else:
242
+ run = preload
243
+
244
+ old_run = _current_flow_run.get()
245
+ _current_flow_run.set(run)
246
+
247
+ run.start(flow_name=self.name, run_id=run_id)
248
+ yield run
249
+ run.end()
208
250
 
209
- old_state = _current_flow_state.get()
210
- _current_flow_state.set(state)
211
- yield state
212
- _current_flow_state.set(old_state)
251
+ _current_flow_run.set(old_run)
213
252
 
214
253
  async def generate(self, *args: P.args, **kwargs: P.kwargs) -> None:
215
- with self.state.run():
216
- await self._fn(*args, **kwargs)
254
+ await self._fn(*args, **kwargs)
217
255
 
218
256
 
219
257
  class Task[**P, R: Result]:
@@ -252,31 +290,39 @@ class Task[**P, R: Result]:
252
290
  state.add_result(result=result, new_input_hash=input_hash, overwrite=True)
253
291
  return result
254
292
  else:
255
- if not isinstance(state.conversation_state.messages[-1], GooseResponse):
293
+ if not isinstance(state.conversation.messages[-1], GooseResponse):
256
294
  raise Honk(
257
295
  "Conversation must alternate between User and Result messages"
258
296
  )
259
297
  return state.result
260
298
 
261
- async def adapt(
262
- self, *, flow_state: FlowState, user_message: UserMessage, index: int = 0
299
+ async def jam(
300
+ self,
301
+ *,
302
+ user_message: UserMessage,
303
+ context: SystemMessage | None = None,
304
+ index: int = 0,
263
305
  ) -> R:
264
- node_state = flow_state.get(task=self, index=index)
306
+ flow_run = self.__get_current_flow_run()
307
+ node_state = flow_run.get(task=self, index=index)
265
308
  if self._adapter is None:
266
309
  raise Honk("No adapter provided for Task")
267
310
 
311
+ if context is not None:
312
+ node_state.set_context(context=context)
268
313
  node_state.add_user_message(message=user_message)
269
- result = await self._adapter(conversation_state=node_state.conversation_state)
314
+
315
+ result = await self._adapter(conversation=node_state.conversation)
270
316
  node_state.add_result(result=result)
271
- flow_state.add(node_state)
317
+ flow_run.add(node_state)
272
318
 
273
319
  return result
274
320
 
275
321
  async def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R:
276
- flow_state = self.__get_current_flow_state()
277
- node_state = flow_state.get_next(task=self)
322
+ flow_run = self.__get_current_flow_run()
323
+ node_state = flow_run.get_next(task=self)
278
324
  result = await self.generate(node_state, *args, **kwargs)
279
- flow_state.add(node_state)
325
+ flow_run.add(node_state)
280
326
  return result
281
327
 
282
328
  def __hash_input(self, *args: P.args, **kwargs: P.kwargs) -> int:
@@ -286,11 +332,11 @@ class Task[**P, R: Result]:
286
332
  except TypeError:
287
333
  raise Honk(f"Unhashable argument to task {self.name}: {args} {kwargs}")
288
334
 
289
- def __get_current_flow_state(self) -> FlowState:
290
- state = _current_flow_state.get()
291
- if state is None:
292
- raise Honk("No current flow state")
293
- return state
335
+ def __get_current_flow_run(self) -> FlowRun:
336
+ run = _current_flow_run.get()
337
+ if run is None:
338
+ raise Honk("No current flow run")
339
+ return run
294
340
 
295
341
 
296
342
  @overload
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: goose-py
3
- Version: 0.2.2
3
+ Version: 0.3.1
4
4
  Summary: A tool for AI workflows based on human-computer collaboration and structured output.
5
5
  Home-page: https://github.com/chelle-ai/goose
6
6
  Keywords: ai,yaml,configuration,llm
@@ -0,0 +1,8 @@
1
+ goose/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
+ goose/agent.py,sha256=cJO6Jlh8En_FGLzjelm0w1pM4rsVqfIOpfKUdpkord4,5650
3
+ goose/errors.py,sha256=-0OyZQJWYTRw5YgnCB2_uorVaUsL6Z0QYQO2FqzCiyg,32
4
+ goose/flow.py,sha256=_ws3Jrp96E5UG7D0dgT0_c1RDs07H-hgJ0mSAmJf_yE,11735
5
+ goose/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
6
+ goose_py-0.3.1.dist-info/METADATA,sha256=wOP3QDcyVUJg-vjBw2KBb1xWdw_NA-50wgJtwg24WQc,1106
7
+ goose_py-0.3.1.dist-info/WHEEL,sha256=Nq82e9rUAnEjt98J6MlVmMCZb-t9cYE2Ir1kpBmnWfs,88
8
+ goose_py-0.3.1.dist-info/RECORD,,
@@ -1,8 +0,0 @@
1
- goose/__init__.py,sha256=wUKxLPSbhPl5Vv4HRK3wuWwanUtpgbe8toT31cM0ZLU,144
2
- goose/agent.py,sha256=5oUoVqvG_qTnsSQyzGjEbhnyk5mswQaoHTXF5SaIrg8,5568
3
- goose/errors.py,sha256=-0OyZQJWYTRw5YgnCB2_uorVaUsL6Z0QYQO2FqzCiyg,32
4
- goose/flow.py,sha256=K29ugPXLuGrsHBVry3WH94FE6_6Epg30-04V5Q69JPo,10482
5
- goose/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
6
- goose_py-0.2.2.dist-info/METADATA,sha256=BnE7WYuV1AlWofi0ejabJMSTp10GLR6xpDZk0AgFo9A,1106
7
- goose_py-0.2.2.dist-info/WHEEL,sha256=Nq82e9rUAnEjt98J6MlVmMCZb-t9cYE2Ir1kpBmnWfs,88
8
- goose_py-0.2.2.dist-info/RECORD,,