goose-py 0.3.12__tar.gz → 0.4.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: goose-py
-Version: 0.3.12
+Version: 0.4.1
 Summary: A tool for AI workflows based on human-computer collaboration and structured output.
 Home-page: https://github.com/chelle-ai/goose
 Keywords: ai,yaml,configuration,llm
@@ -7,6 +7,7 @@ from typing import Any, ClassVar, Literal, NotRequired, Protocol, TypedDict
 
 from litellm import acompletion
 from pydantic import BaseModel, computed_field
+from goose.result import Result, TextResult
 
 
 class GeminiModel(StrEnum):
@@ -115,7 +116,7 @@ class AgentResponseDump(TypedDict):
     duration_ms: int
 
 
-class AgentResponse[R: BaseModel](BaseModel):
+class AgentResponse[R: BaseModel | str](BaseModel):
     INPUT_CENTS_PER_MILLION_TOKENS: ClassVar[dict[GeminiModel, float]] = {
         GeminiModel.FLASH_8B: 30,
         GeminiModel.FLASH: 15,
@@ -186,6 +187,12 @@ class AgentResponse[R: BaseModel](BaseModel):
             json.dumps(message) for message in minimized_input_messages
         ]
 
+        output_message = (
+            self.response.model_dump_json()
+            if isinstance(self.response, BaseModel)
+            else self.response
+        )
+
         return {
             "run_id": self.run_id,
             "flow_name": self.flow_name,
@@ -193,7 +200,7 @@ class AgentResponse[R: BaseModel](BaseModel):
             "model": self.model.value,
             "system_message": minimized_system_message,
             "input_messages": minimized_input_messages,
-            "output_message": self.response.model_dump_json(),
+            "output_message": output_message,
             "input_tokens": self.input_tokens,
             "output_tokens": self.output_tokens,
             "input_cost": self.input_cost,
@@ -221,13 +228,13 @@ class Agent:
         self.run_id = run_id
         self.logger = logger
 
-    async def __call__[R: BaseModel](
+    async def __call__[R: Result](
         self,
         *,
         messages: list[UserMessage | AssistantMessage],
         model: GeminiModel,
-        response_model: type[R],
         task_name: str,
+        response_model: type[R] = TextResult,
         system: SystemMessage | None = None,
     ) -> R:
         start_time = datetime.now()
@@ -235,22 +242,25 @@ class Agent:
         if system is not None:
             rendered_messages.insert(0, system.render())
 
-        response = await acompletion(
-            model=model.value,
-            messages=rendered_messages,
-            response_format={
-                "type": "json_object",
-                "response_schema": response_model.model_json_schema(),
-                "enforce_validation": True,
-            },
-        )
-
-        if len(response.choices) == 0:
-            raise RuntimeError("No content returned from LLM call.")
+        if response_model is TextResult:
+            response = await acompletion(model=model.value, messages=rendered_messages)
+            parsed_response = response_model.model_validate(
+                {"text": response.choices[0].message.content}
+            )
+        else:
+            response = await acompletion(
+                model=model.value,
+                messages=rendered_messages,
+                response_format={
+                    "type": "json_object",
+                    "response_schema": response_model.model_json_schema(),
+                    "enforce_validation": True,
+                },
+            )
+            parsed_response = response_model.model_validate_json(
+                response.choices[0].message.content
+            )
 
-        parsed_response = response_model.model_validate_json(
-            response.choices[0].message.content
-        )
         end_time = datetime.now()
         agent_response = AgentResponse(
             response=parsed_response,
@@ -271,4 +281,4 @@ class Agent:
         else:
             logging.info(agent_response.model_dump())
 
-        return agent_response.response
+        return parsed_response
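
With this release, Agent.__call__ is generic over Result rather than BaseModel, and response_model defaults to TextResult: the default path issues a plain completion with no response_format, while any other Result subclass keeps the schema-enforced JSON path. A minimal usage sketch (how the Agent instance is constructed is not shown in this diff and is assumed here):

    # Sketch only: Agent construction and message helpers are assumptions.
    from goose.agent import Agent, GeminiModel, UserMessage
    from goose.result import Result, TextResult

    class Summary(Result):  # hypothetical structured output model
        title: str
        body: str

    async def demo(agent: Agent, messages: list[UserMessage]) -> None:
        # Default response_model=TextResult: plain completion, no JSON schema.
        text = await agent(
            messages=messages, model=GeminiModel.FLASH, task_name="chat"
        )
        print(text.text)

        # Any other Result subclass: schema-enforced JSON completion.
        summary = await agent(
            messages=messages,
            model=GeminiModel.FLASH,
            task_name="summarize",
            response_model=Summary,
        )
        print(summary.title)

Note that the rewrite drops the old empty-choices guard (the RuntimeError raise), so a completion that returns no choices now surfaces as an IndexError from response.choices[0] instead.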
@@ -13,7 +13,7 @@ from typing import (
     overload,
 )
 
-from pydantic import BaseModel, ConfigDict
+from pydantic import BaseModel
 
 from goose.agent import (
     Agent,
@@ -25,14 +25,11 @@ from goose.agent import (
 )
 from goose.errors import Honk
 from goose.store import IFlowRunStore, InMemoryFlowRunStore
+from goose.result import Result
 
 SerializedFlowRun = NewType("SerializedFlowRun", str)
 
 
-class Result(BaseModel):
-    model_config = ConfigDict(frozen=True)
-
-
 class Conversation[R: Result](BaseModel):
     user_messages: list[UserMessage]
     result_messages: list[R]
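
Result has moved out of goose.flow into the new goose.result module (the new-file hunk near the end of this diff), so downstream imports change accordingly:

    from goose.flow import Result    # 0.3.x
    from goose.result import Result  # 0.4.x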
@@ -115,6 +112,8 @@ class FlowRun:
         self._flow_name = ""
         self._id = ""
         self._agent: Agent | None = None
+        self._flow_args: tuple[Any, ...] | None = None
+        self._flow_kwargs: dict[str, Any] | None = None
 
     @property
     def flow_name(self) -> str:
@@ -130,17 +129,12 @@ class FlowRun:
             raise Honk("Agent is only accessible once a run is started")
         return self._agent
 
-    def add(self, node_state: NodeState[Any], /) -> None:
-        key = (node_state.task_name, node_state.index)
-        self._node_states[key] = node_state.model_dump_json()
-
-    def get_next[R: Result](self, *, task: "Task[Any, R]") -> NodeState[R]:
-        if task.name not in self._last_requested_indices:
-            self._last_requested_indices[task.name] = 0
-        else:
-            self._last_requested_indices[task.name] += 1
+    @property
+    def flow_inputs(self) -> tuple[tuple[Any, ...], dict[str, Any]]:
+        if self._flow_args is None or self._flow_kwargs is None:
+            raise Honk("This Flow run has not been executed before")
 
-        return self.get(task=task, index=self._last_requested_indices[task.name])
+        return self._flow_args, self._flow_kwargs
 
     def get_all[R: Result](self, *, task: "Task[Any, R]") -> list[NodeState[R]]:
         matching_nodes: list[NodeState[R]] = []
@@ -166,6 +160,22 @@ class FlowRun:
             last_hash=0,
         )
 
+    def set_flow_inputs(self, *args: Any, **kwargs: Any) -> None:
+        self._flow_args = args
+        self._flow_kwargs = kwargs
+
+    def add_node_state(self, node_state: NodeState[Any], /) -> None:
+        key = (node_state.task_name, node_state.index)
+        self._node_states[key] = node_state.model_dump_json()
+
+    def get_next[R: Result](self, *, task: "Task[Any, R]") -> NodeState[R]:
+        if task.name not in self._last_requested_indices:
+            self._last_requested_indices[task.name] = 0
+        else:
+            self._last_requested_indices[task.name] += 1
+
+        return self.get(task=task, index=self._last_requested_indices[task.name])
+
     def start(
         self,
         *,
@@ -192,25 +202,35 @@ class FlowRun:
             del self._node_states[key]
 
     def dump(self) -> SerializedFlowRun:
+        flow_args, flow_kwargs = self.flow_inputs
+
         return SerializedFlowRun(
             json.dumps(
                 {
-                    ":".join([task_name, str(index)]): value
-                    for (task_name, index), value in self._node_states.items()
+                    "node_states": {
+                        ":".join([task_name, str(index)]): value
+                        for (task_name, index), value in self._node_states.items()
+                    },
+                    "flow_args": list(flow_args),
+                    "flow_kwargs": flow_kwargs,
                 }
             )
         )
 
     @classmethod
-    def load(cls, run: SerializedFlowRun, /) -> Self:
+    def load(cls, serialized_flow_run: SerializedFlowRun, /) -> Self:
         flow_run = cls()
-        raw_node_states = json.loads(run)
+        run = json.loads(serialized_flow_run)
+
         new_node_states: dict[tuple[str, int], str] = {}
-        for key, node_state in raw_node_states.items():
+        for key, node_state in run["node_states"].items():
             task_name, index = tuple(key.split(":"))
             new_node_states[(task_name, int(index))] = node_state
-
         flow_run._node_states = new_node_states
+
+        flow_run._flow_args = tuple(run["flow_args"])
+        flow_run._flow_kwargs = run["flow_kwargs"]
+
         return flow_run
 
 
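Serialized runs now carry the captured flow inputs alongside node states. Sketching the envelope dump() produces (the keys are confirmed above; the values are illustrative):

    {
      "node_states": {"<task_name>:<index>": "<node-state JSON string>"},
      "flow_args": ["positional", "arguments"],
      "flow_kwargs": {"keyword": "arguments"}
    }

Two consequences follow from the code above: load() reads run["node_states"] unconditionally, so flat 0.3.x payloads no longer deserialize, and dump() now raises Honk (via the flow_inputs property) if the run was never executed.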
@@ -264,8 +284,21 @@ class Flow[**P]:
             _current_flow_run.set(old_run)
 
     async def generate(self, *args: P.args, **kwargs: P.kwargs) -> None:
+        flow_run = _current_flow_run.get()
+        if flow_run is None:
+            raise Honk("No current flow run")
+
+        flow_run.set_flow_inputs(*args, **kwargs)
         await self._fn(*args, **kwargs)
 
+    async def regenerate(self) -> None:
+        flow_run = _current_flow_run.get()
+        if flow_run is None:
+            raise Honk("No current flow run")
+
+        flow_args, flow_kwargs = flow_run.flow_inputs
+        await self._fn(*flow_args, **flow_kwargs)
+
 
 class Task[**P, R: Result]:
     def __init__(
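
generate() now records its arguments on the current run before invoking the wrapped function, which is what lets regenerate() replay a run without the caller re-supplying inputs. A hedged sketch, assuming my_flow is a Flow and that both calls happen inside the flow's usual run context (either raises Honk("No current flow run") otherwise):

    await my_flow.generate("doc-123", locale="en")  # inputs captured on the run
    await my_flow.regenerate()                      # replays the stored inputs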
@@ -323,7 +356,7 @@ class Task[**P, R: Result]:
 
         result = await self._adapter(conversation=node_state.conversation)
         node_state.add_result(result=result)
-        flow_run.add(node_state)
+        flow_run.add_node_state(node_state)
 
         return result
 
@@ -331,7 +364,7 @@ class Task[**P, R: Result]:
         flow_run = self.__get_current_flow_run()
         node_state = flow_run.get_next(task=self)
         result = await self.generate(node_state, *args, **kwargs)
-        flow_run.add(node_state)
+        flow_run.add_node_state(node_state)
         return result
 
     def __hash_task_call(self, *args: P.args, **kwargs: P.kwargs) -> int:
@@ -0,0 +1,9 @@
+from pydantic import BaseModel, ConfigDict
+
+
+class Result(BaseModel):
+    model_config = ConfigDict(frozen=True)
+
+
+class TextResult(Result):
+    text: str
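
The new goose/result.py module gives results a frozen (immutable) pydantic base, with TextResult as the built-in wrapper for unstructured model output. Because of frozen=True, mutating a validated result raises (illustrative):

    from goose.result import TextResult

    result = TextResult(text="hello")
    result.text = "changed"  # pydantic v2 raises ValidationError: instance is frozen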
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "goose-py"
-version = "0.3.12"
+version = "0.4.1"
 description = "A tool for AI workflows based on human-computer collaboration and structured output."
 authors = [
     "Nash Taylor <nash@chelle.ai>",