goose-py 0.2.1-py3-none-any.whl → 0.3.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
goose/__init__.py CHANGED
@@ -1,3 +0,0 @@
-from goose.flow import ConversationState, FlowState, Result, flow, task
-
-__all__ = ["ConversationState", "FlowState", "Result", "flow", "task"]
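The package root is emptied in 0.3.0, so nothing can be imported from `goose` directly any more. A minimal import sketch under that assumption — the `flow` and `task` decorators are assumed to still live in `goose.flow`, as they did in 0.2.1, and the other names are taken from the diffs below:

# 0.2.1
# from goose import ConversationState, FlowState, Result, flow, task

# 0.3.0 — import from the submodules directly (assumed layout based on this diff)
from goose.flow import Conversation, FlowRun, Result, flow, task
from goose.agent import Agent, SystemMessage, UserMessage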
goose/agent.py CHANGED
@@ -88,10 +88,13 @@ class AssistantMessage(BaseModel):


 class SystemMessage(BaseModel):
-    text: str
+    parts: list[TextMessagePart | MediaMessagePart]

     def render(self) -> LLMMessage:
-        return {"role": "system", "content": [{"type": "text", "text": self.text}]}
+        return {
+            "role": "system",
+            "content": [part.render() for part in self.parts],
+        }


 class AgentResponse[R: BaseModel](BaseModel):
@@ -109,7 +112,7 @@ class AgentResponse[R: BaseModel](BaseModel):
     }

     response: R
-    id: str
+    run_name: str
     flow_name: str
     task_name: str
     model: GeminiModel
@@ -125,14 +128,20 @@ class AgentResponse[R: BaseModel](BaseModel):
     def duration_ms(self) -> int:
         return int((self.end_time - self.start_time).total_seconds() * 1000)

+    @computed_field
+    @property
+    def input_cost(self) -> float:
+        return self.INPUT_CENTS_PER_MILLION_TOKENS[self.model] * self.input_tokens
+
+    @computed_field
+    @property
+    def output_cost(self) -> float:
+        return self.OUTPUT_CENTS_PER_MILLION_TOKENS[self.model] * self.output_tokens
+
     @computed_field
     @property
     def total_cost(self) -> float:
-        input_cost = self.INPUT_CENTS_PER_MILLION_TOKENS[self.model] * self.input_tokens
-        output_cost = (
-            self.OUTPUT_CENTS_PER_MILLION_TOKENS[self.model] * self.output_tokens
-        )
-        return input_cost + output_cost
+        return self.input_cost + self.output_cost


 class Agent:
@@ -140,11 +149,11 @@ class Agent:
         self,
         *,
         flow_name: str,
-        run_id: str,
+        run_name: str,
         logger: Callable[[AgentResponse[Any]], Awaitable[None]] | None = None,
     ) -> None:
         self.flow_name = flow_name
-        self.run_id = run_id
+        self.run_name = run_name
         self.logger = logger

     async def __call__[R: BaseModel](
@@ -180,7 +189,7 @@ class Agent:
         end_time = datetime.now()
         agent_response = AgentResponse(
             response=parsed_response,
-            id=self.run_id,
+            run_name=self.run_name,
             flow_name=self.flow_name,
             task_name=task_name,
             model=model,
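In summary, agent.py renames the run identifier from `run_id`/`id` to `run_name`, replaces `SystemMessage.text` with a list of message parts, and splits cost reporting into per-direction fields. A hedged sketch of the constructor change — the flow and run names are hypothetical values; only the keyword names come from the diff:

from goose.agent import Agent

# 0.2.1: agent = Agent(flow_name="my_flow", run_id="abc123")
agent = Agent(flow_name="my_flow", run_name="2025-01-01-nightly")

# AgentResponse objects passed to the logger now carry run_name instead of id,
# and total_cost is the sum of the new computed fields:
#   response.total_cost == response.input_cost + response.output_cost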
goose/flow.py CHANGED
@@ -14,10 +14,10 @@ from typing import (

 from pydantic import BaseModel, ConfigDict, field_validator

-from goose.agent import UserMessage
+from goose.agent import Agent, AssistantMessage, LLMMessage, SystemMessage, UserMessage
 from goose.errors import Honk

-SerializedFlowState = NewType("SerializedFlowState", str)
+SerializedFlowRun = NewType("SerializedFlowRun", str)


 class Result(BaseModel):
@@ -28,8 +28,9 @@ class GooseResponse[R: Result](BaseModel):
     result: R


-class ConversationState[R: Result](BaseModel):
+class Conversation[R: Result](BaseModel):
     messages: list[UserMessage | GooseResponse[R]]
+    context: SystemMessage | None = None

     @field_validator("messages")
     def alternates_starting_with_result(
@@ -56,28 +57,44 @@ class ConversationState[R: Result](BaseModel):
     def awaiting_response(self) -> bool:
         return len(self.messages) % 2 == 0

+    def render(self) -> list[LLMMessage]:
+        messages: list[LLMMessage] = []
+        if self.context is not None:
+            messages.append(self.context.render())
+
+        for message in self.messages:
+            if isinstance(message, UserMessage):
+                messages.append(message.render())
+            else:
+                messages.append(
+                    AssistantMessage(text=message.result.model_dump_json()).render()
+                )
+
+        return messages
+

 class IAdapter[ResultT: Result](Protocol):
-    async def __call__(
-        self, *, conversation_state: ConversationState[ResultT]
-    ) -> ResultT: ...
+    async def __call__(self, *, conversation: Conversation[ResultT]) -> ResultT: ...


 class NodeState[ResultT: Result](BaseModel):
     task_name: str
     index: int
-    conversation_state: ConversationState[ResultT]
+    conversation: Conversation[ResultT]
     last_input_hash: int
-    pinned: bool

     @property
     def result(self) -> ResultT:
-        last_message = self.conversation_state.messages[-1]
+        last_message = self.conversation.messages[-1]
         if isinstance(last_message, GooseResponse):
             return last_message.result
         else:
             raise Honk("Node awaiting response, has no result")

+    def set_context(self, *, context: SystemMessage) -> Self:
+        self.conversation.context = context
+        return self
+
     def add_result(
         self,
         *,
@@ -86,33 +103,37 @@ class NodeState[ResultT: Result](BaseModel):
         overwrite: bool = False,
     ) -> Self:
         if overwrite:
-            if len(self.conversation_state.messages) == 0:
-                self.conversation_state.messages.append(GooseResponse(result=result))
+            if len(self.conversation.messages) == 0:
+                self.conversation.messages.append(GooseResponse(result=result))
             else:
-                self.conversation_state.messages[-1] = GooseResponse(result=result)
+                self.conversation.messages[-1] = GooseResponse(result=result)
         else:
-            self.conversation_state.messages.append(GooseResponse(result=result))
+            self.conversation.messages.append(GooseResponse(result=result))
         if new_input_hash is not None:
             self.last_input_hash = new_input_hash
         return self

     def add_user_message(self, *, message: UserMessage) -> Self:
-        self.conversation_state.messages.append(message)
-        return self
-
-    def pin(self) -> Self:
-        self.pinned = True
-        return self
-
-    def unpin(self) -> Self:
-        self.pinned = False
+        self.conversation.messages.append(message)
         return self


-class FlowState:
+class FlowRun:
     def __init__(self) -> None:
         self._node_states: dict[tuple[str, int], str] = {}
         self._last_requested_indices: dict[str, int] = {}
+        self._name = ""
+        self._agent: Agent | None = None
+
+    @property
+    def name(self) -> str:
+        return self._name
+
+    @property
+    def agent(self) -> Agent:
+        if self._agent is None:
+            raise Honk("Agent is only accessible once a run is started")
+        return self._agent

     def add(self, node_state: NodeState[Any], /) -> None:
         key = (node_state.task_name, node_state.index)
@@ -144,19 +165,22 @@ class FlowState:
         return NodeState[task.result_type](
             task_name=task.name,
             index=index or 0,
-            conversation_state=ConversationState[task.result_type](messages=[]),
+            conversation=Conversation[task.result_type](messages=[]),
             last_input_hash=0,
-            pinned=False,
         )

-    @contextmanager
-    def run(self) -> Iterator[Self]:
+    def start(self, *, name: str) -> None:
         self._last_requested_indices = {}
-        yield self
+        self._name = name
+        self._agent = Agent(flow_name=self.name, run_name=name)
+
+    def end(self) -> None:
         self._last_requested_indices = {}
+        self._name = ""
+        self._agent = None

-    def dump(self) -> SerializedFlowState:
-        return SerializedFlowState(
+    def dump(self) -> SerializedFlowRun:
+        return SerializedFlowRun(
             json.dumps(
                 {
                     ":".join([task_name, str(index)]): value
@@ -166,20 +190,20 @@ class FlowState:
         )

     @classmethod
-    def load(cls, state: SerializedFlowState) -> Self:
-        flow_state = cls()
-        raw_node_states = json.loads(state)
+    def load(cls, run: SerializedFlowRun, /) -> Self:
+        flow_run = cls()
+        raw_node_states = json.loads(run)
         new_node_states: dict[tuple[str, int], str] = {}
         for key, node_state in raw_node_states.items():
             task_name, index = tuple(key.split(":"))
             new_node_states[(task_name, int(index))] = node_state

-        flow_state._node_states = new_node_states
-        return flow_state
+        flow_run._node_states = new_node_states
+        return flow_run


-_current_flow_state: ContextVar[FlowState | None] = ContextVar(
-    "current_flow_state", default=None
+_current_flow_run: ContextVar[FlowRun | None] = ContextVar(
+    "current_flow_run", default=None
 )

@@ -195,25 +219,28 @@ class Flow[**P]:
         return self._name or self._fn.__name__

     @property
-    def state(self) -> FlowState:
-        state = _current_flow_state.get()
-        if state is None:
-            raise Honk("No current flow state")
-        return state
+    def current_run(self) -> FlowRun:
+        run = _current_flow_run.get()
+        if run is None:
+            raise Honk("No current flow run")
+        return run

     @contextmanager
-    def run(self, *, state: FlowState | None = None) -> Iterator[FlowState]:
-        if state is None:
-            state = FlowState()
+    def start_run(self, *, name: str, run: FlowRun | None = None) -> Iterator[FlowRun]:
+        if run is None:
+            run = FlowRun()
+
+        old_run = _current_flow_run.get()
+        _current_flow_run.set(run)
+
+        run.start(name=name)
+        yield run
+        run.end()

-        old_state = _current_flow_state.get()
-        _current_flow_state.set(state)
-        yield state
-        _current_flow_state.set(old_state)
+        _current_flow_run.set(old_run)

     async def generate(self, *args: P.args, **kwargs: P.kwargs) -> None:
-        with self.state.run():
-            await self._fn(*args, **kwargs)
+        await self._fn(*args, **kwargs)


 class Task[**P, R: Result]:
@@ -252,31 +279,39 @@ class Task[**P, R: Result]:
             state.add_result(result=result, new_input_hash=input_hash, overwrite=True)
             return result
         else:
-            if not isinstance(state.conversation_state.messages[-1], GooseResponse):
+            if not isinstance(state.conversation.messages[-1], GooseResponse):
                 raise Honk(
                     "Conversation must alternate between User and Result messages"
                 )
             return state.result

-    async def adapt(
-        self, *, flow_state: FlowState, user_message: UserMessage, index: int = 0
+    async def jam(
+        self,
+        *,
+        user_message: UserMessage,
+        context: SystemMessage | None = None,
+        index: int = 0,
     ) -> R:
-        node_state = flow_state.get(task=self, index=index)
+        flow_run = self.__get_current_flow_run()
+        node_state = flow_run.get(task=self, index=index)
         if self._adapter is None:
             raise Honk("No adapter provided for Task")

+        if context is not None:
+            node_state.set_context(context=context)
         node_state.add_user_message(message=user_message)
-        result = await self._adapter(conversation_state=node_state.conversation_state)
+
+        result = await self._adapter(conversation=node_state.conversation)
         node_state.add_result(result=result)
-        flow_state.add(node_state)
+        flow_run.add(node_state)

         return result

     async def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R:
-        flow_state = self.__get_current_flow_state()
-        node_state = flow_state.get_next(task=self)
+        flow_run = self.__get_current_flow_run()
+        node_state = flow_run.get_next(task=self)
         result = await self.generate(node_state, *args, **kwargs)
-        flow_state.add(node_state)
+        flow_run.add(node_state)
         return result

     def __hash_input(self, *args: P.args, **kwargs: P.kwargs) -> int:
@@ -286,11 +321,11 @@ class Task[**P, R: Result]:
         except TypeError:
             raise Honk(f"Unhashable argument to task {self.name}: {args} {kwargs}")

-    def __get_current_flow_state(self) -> FlowState:
-        state = _current_flow_state.get()
-        if state is None:
-            raise Honk("No current flow state")
-        return state
+    def __get_current_flow_run(self) -> FlowRun:
+        run = _current_flow_run.get()
+        if run is None:
+            raise Honk("No current flow run")
+        return run


 @overload
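Taken together, the flow.py changes replace FlowState with FlowRun, move run management onto the flow (`start_run(name=...)` instead of the old `state.run()` context manager), have `FlowRun.start()` build an Agent from the run's name, and rename `Task.adapt()` to `Task.jam()`, which now resolves the current run from the context variable instead of taking a `flow_state` argument. A minimal lifecycle sketch, assuming `my_flow` is a flow decorated with `@flow` and `documents` is its input (both hypothetical):

import asyncio

from goose.flow import FlowRun

async def main() -> None:
    # 0.2.1: with my_flow.run(state=previous_state) as state: ...
    with my_flow.start_run(name="nightly") as run:
        await my_flow.generate(documents)
        serialized = run.dump()  # SerializedFlowRun, formerly SerializedFlowState

    # Resume later: load the serialized run and start it under the same name.
    # start_run() re-installs it in the context variable; inside the block,
    # a task's answer can then be refined with e.g.
    #   await my_task.jam(user_message=..., context=...)
    with my_flow.start_run(name="nightly", run=FlowRun.load(serialized)) as run:
        await my_flow.generate(documents)

asyncio.run(main())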
goose_py-0.3.0.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: goose-py
-Version: 0.2.1
+Version: 0.3.0
 Summary: A tool for AI workflows based on human-computer collaboration and structured output.
 Home-page: https://github.com/chelle-ai/goose
 Keywords: ai,yaml,configuration,llm
goose_py-0.3.0.dist-info/RECORD ADDED
@@ -0,0 +1,8 @@
+goose/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+goose/agent.py,sha256=j8r9p7PqIdXCX_Vl8pUTrF2SshJXUDitgJlCV_C6hdw,5662
+goose/errors.py,sha256=-0OyZQJWYTRw5YgnCB2_uorVaUsL6Z0QYQO2FqzCiyg,32
+goose/flow.py,sha256=5a9lIHTyW5hLeMdt1AfIxe5z-VlTti_3OJ1o4f_nuoo,11458
+goose/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+goose_py-0.3.0.dist-info/METADATA,sha256=dMNVLzZZ7KyEUacFGWqLeNJuv0ni5-GS_FP9ifaF4bM,1106
+goose_py-0.3.0.dist-info/WHEEL,sha256=Nq82e9rUAnEjt98J6MlVmMCZb-t9cYE2Ir1kpBmnWfs,88
+goose_py-0.3.0.dist-info/RECORD,,
goose_py-0.2.1.dist-info/RECORD DELETED
@@ -1,8 +0,0 @@
-goose/__init__.py,sha256=wUKxLPSbhPl5Vv4HRK3wuWwanUtpgbe8toT31cM0ZLU,144
-goose/agent.py,sha256=ck_3uRdyhmgaygowy436CQYZIEFa6sA-7y5bzRNAWd8,5454
-goose/errors.py,sha256=-0OyZQJWYTRw5YgnCB2_uorVaUsL6Z0QYQO2FqzCiyg,32
-goose/flow.py,sha256=K29ugPXLuGrsHBVry3WH94FE6_6Epg30-04V5Q69JPo,10482
-goose/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-goose_py-0.2.1.dist-info/METADATA,sha256=_BuAp1wz1Gmm-BX87W15FjYATHI5_YJUm_Ts9ROp-jM,1106
-goose_py-0.2.1.dist-info/WHEEL,sha256=Nq82e9rUAnEjt98J6MlVmMCZb-t9cYE2Ir1kpBmnWfs,88
-goose_py-0.2.1.dist-info/RECORD,,