goose-py 0.9.16__py3-none-any.whl → 0.10.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
goose/_internal/agent.py CHANGED
@@ -1,119 +1,15 @@
1
- import json
2
1
  import logging
3
2
  from datetime import datetime
4
- from typing import Any, ClassVar, Protocol, TypedDict
3
+ from typing import Any, Literal, Protocol, overload
5
4
 
6
5
  from litellm import acompletion
7
- from pydantic import BaseModel, computed_field
8
-
9
- from .result import Result, TextResult
10
- from .types.agent import AIModel, LLMMessage
6
+ from pydantic import ValidationError
11
7
 
8
+ from goose._internal.types.telemetry import AgentResponse
9
+ from goose.errors import Honk
12
10
 
13
# Flat, log-friendly summary of an AgentResponse: the return type of
# AgentResponse.minimized_dump(). Message fields hold JSON-encoded strings.
AgentResponseDump = TypedDict(
    "AgentResponseDump",
    {
        "run_id": str,
        "flow_name": str,
        "task_name": str,
        "model": str,
        "system_message": str,
        "input_messages": list[str],
        "output_message": str,
        "input_cost": float,
        "output_cost": float,
        "total_cost": float,
        "input_tokens": int,
        "output_tokens": int,
        "start_time": datetime,
        "end_time": datetime,
        "duration_ms": int,
    },
)
29
-
30
-
31
- class AgentResponse[R: BaseModel | str](BaseModel):
32
- INPUT_CENTS_PER_MILLION_TOKENS: ClassVar[dict[AIModel, float]] = {
33
- AIModel.VERTEX_FLASH_8B: 30,
34
- AIModel.VERTEX_FLASH: 15,
35
- AIModel.VERTEX_PRO: 500,
36
- AIModel.GEMINI_FLASH_8B: 30,
37
- AIModel.GEMINI_FLASH: 15,
38
- AIModel.GEMINI_PRO: 500,
39
- }
40
- OUTPUT_CENTS_PER_MILLION_TOKENS: ClassVar[dict[AIModel, float]] = {
41
- AIModel.VERTEX_FLASH_8B: 30,
42
- AIModel.VERTEX_FLASH: 15,
43
- AIModel.VERTEX_PRO: 500,
44
- AIModel.GEMINI_FLASH_8B: 30,
45
- AIModel.GEMINI_FLASH: 15,
46
- AIModel.GEMINI_PRO: 500,
47
- }
48
-
49
- response: R
50
- run_id: str
51
- flow_name: str
52
- task_name: str
53
- model: AIModel
54
- system: LLMMessage | None = None
55
- input_messages: list[LLMMessage]
56
- input_tokens: int
57
- output_tokens: int
58
- start_time: datetime
59
- end_time: datetime
60
-
61
- @computed_field
62
- @property
63
- def duration_ms(self) -> int:
64
- return int((self.end_time - self.start_time).total_seconds() * 1000)
65
-
66
- @computed_field
67
- @property
68
- def input_cost(self) -> float:
69
- return self.INPUT_CENTS_PER_MILLION_TOKENS[self.model] * self.input_tokens / 1_000_000
70
-
71
- @computed_field
72
- @property
73
- def output_cost(self) -> float:
74
- return self.OUTPUT_CENTS_PER_MILLION_TOKENS[self.model] * self.output_tokens / 1_000_000
75
-
76
- @computed_field
77
- @property
78
- def total_cost(self) -> float:
79
- return self.input_cost + self.output_cost
80
-
81
- def minimized_dump(self) -> AgentResponseDump:
82
- if self.system is None:
83
- minimized_system_message = ""
84
- else:
85
- minimized_system_message = self.system
86
- for part in minimized_system_message["content"]:
87
- if part["type"] == "image_url":
88
- part["image_url"] = "__MEDIA__"
89
- minimized_system_message = json.dumps(minimized_system_message)
90
-
91
- minimized_input_messages = [message for message in self.input_messages]
92
- for message in minimized_input_messages:
93
- for part in message["content"]:
94
- if part["type"] == "image_url":
95
- part["image_url"] = "__MEDIA__"
96
- minimized_input_messages = [json.dumps(message) for message in minimized_input_messages]
97
-
98
- output_message = self.response.model_dump_json() if isinstance(self.response, BaseModel) else self.response
99
-
100
- return {
101
- "run_id": self.run_id,
102
- "flow_name": self.flow_name,
103
- "task_name": self.task_name,
104
- "model": self.model.value,
105
- "system_message": minimized_system_message,
106
- "input_messages": minimized_input_messages,
107
- "output_message": output_message,
108
- "input_tokens": self.input_tokens,
109
- "output_tokens": self.output_tokens,
110
- "input_cost": self.input_cost,
111
- "output_cost": self.output_cost,
112
- "total_cost": self.total_cost,
113
- "start_time": self.start_time,
114
- "end_time": self.end_time,
115
- "duration_ms": self.duration_ms,
116
- }
11
+ from .result import FindReplaceResponse, Result, TextResult
12
+ from .types.agent import AIModel, LLMMessage
117
13
 
118
14
 
119
15
  class IAgentLogger(Protocol):
@@ -132,7 +28,7 @@ class Agent:
132
28
  self.run_id = run_id
133
29
  self.logger = logger
134
30
 
135
- async def __call__[R: Result](
31
+ async def generate[R: Result](
136
32
  self,
137
33
  *,
138
34
  messages: list[LLMMessage],
@@ -177,3 +73,173 @@ class Agent:
177
73
  logging.info(agent_response.model_dump())
178
74
 
179
75
  return parsed_response
76
+
77
async def ask(
    self, *, messages: list[LLMMessage], model: AIModel, task_name: str, system: LLMMessage | None = None
) -> str:
    """Send a free-form conversation to ``model`` and return the reply text.

    Args:
        messages: Conversation to send. It is not modified: when ``system`` is
            given, it is prepended to a new list.
        model: Model to call; its ``value`` is passed to ``acompletion``.
        task_name: Task name recorded on the telemetry record.
        system: Optional system message placed ahead of ``messages``.

    Returns:
        The ``content`` of the first choice in the completion.

    The call is recorded as an ``AgentResponse``. That record is passed to
    ``self.logger`` when one is set, and otherwise written with ``logging.info``.
    """
    start_time = datetime.now()

    # Build a new list rather than inserting into ``messages``: that list
    # belongs to the caller, who should not see the system message appear in it.
    if system is not None:
        messages = [system, *messages]
    response = await acompletion(model=model.value, messages=messages)
    content = response.choices[0].message.content

    end_time = datetime.now()
    agent_response = AgentResponse(
        response=content,
        run_id=self.run_id,
        flow_name=self.flow_name,
        task_name=task_name,
        model=model,
        system=system,
        input_messages=messages,
        input_tokens=response.usage.prompt_tokens,
        output_tokens=response.usage.completion_tokens,
        start_time=start_time,
        end_time=end_time,
    )

    if self.logger is not None:
        await self.logger(response=agent_response)
    else:
        logging.info(agent_response.model_dump())

    return content
107
+
108
+ async def refine[R: Result](
109
+ self,
110
+ *,
111
+ messages: list[LLMMessage],
112
+ model: AIModel,
113
+ task_name: str,
114
+ response_model: type[R],
115
+ system: LLMMessage | None = None,
116
+ ) -> R:
117
+ start_time = datetime.now()
118
+
119
+ if system is not None:
120
+ messages.insert(0, system)
121
+
122
+ find_replace_response = await acompletion(
123
+ model=model.value, messages=messages, response_format=FindReplaceResponse
124
+ )
125
+ parsed_find_replace_response = FindReplaceResponse.model_validate_json(
126
+ find_replace_response.choices[0].message.content
127
+ )
128
+
129
+ end_time = datetime.now()
130
+ agent_response = AgentResponse(
131
+ response=parsed_find_replace_response,
132
+ run_id=self.run_id,
133
+ flow_name=self.flow_name,
134
+ task_name=task_name,
135
+ model=model,
136
+ system=system,
137
+ input_messages=messages,
138
+ input_tokens=find_replace_response.usage.prompt_tokens,
139
+ output_tokens=find_replace_response.usage.completion_tokens,
140
+ start_time=start_time,
141
+ end_time=end_time,
142
+ )
143
+
144
+ if self.logger is not None:
145
+ await self.logger(response=agent_response)
146
+ else:
147
+ logging.info(agent_response.model_dump())
148
+
149
+ refined_response = self.__apply_find_replace(
150
+ result=self.__find_last_result(messages=messages, response_model=response_model),
151
+ find_replace_response=parsed_find_replace_response,
152
+ response_model=response_model,
153
+ )
154
+
155
+ return refined_response
156
+
157
+ @overload
158
+ async def __call__[R: Result](
159
+ self,
160
+ *,
161
+ messages: list[LLMMessage],
162
+ model: AIModel,
163
+ task_name: str,
164
+ mode: Literal["generate"],
165
+ response_model: type[R],
166
+ system: LLMMessage | None = None,
167
+ ) -> R: ...
168
+
169
+ @overload
170
+ async def __call__[R: Result](
171
+ self,
172
+ *,
173
+ messages: list[LLMMessage],
174
+ model: AIModel,
175
+ task_name: str,
176
+ mode: Literal["ask"],
177
+ response_model: type[R] = TextResult,
178
+ system: LLMMessage | None = None,
179
+ ) -> str: ...
180
+
181
+ @overload
182
+ async def __call__[R: Result](
183
+ self,
184
+ *,
185
+ messages: list[LLMMessage],
186
+ model: AIModel,
187
+ task_name: str,
188
+ response_model: type[R],
189
+ mode: Literal["refine"],
190
+ system: LLMMessage | None = None,
191
+ ) -> R: ...
192
+
193
+ @overload
194
+ async def __call__[R: Result](
195
+ self,
196
+ *,
197
+ messages: list[LLMMessage],
198
+ model: AIModel,
199
+ task_name: str,
200
+ response_model: type[R],
201
+ system: LLMMessage | None = None,
202
+ ) -> R: ...
203
+
204
+ async def __call__[R: Result](
205
+ self,
206
+ *,
207
+ messages: list[LLMMessage],
208
+ model: AIModel,
209
+ task_name: str,
210
+ response_model: type[R] = TextResult,
211
+ mode: Literal["generate", "ask", "refine"] = "generate",
212
+ system: LLMMessage | None = None,
213
+ ) -> R | str:
214
+ match mode:
215
+ case "generate":
216
+ return await self.generate(
217
+ messages=messages, model=model, task_name=task_name, response_model=response_model, system=system
218
+ )
219
+ case "ask":
220
+ return await self.ask(messages=messages, model=model, task_name=task_name, system=system)
221
+ case "refine":
222
+ return await self.refine(
223
+ messages=messages, model=model, task_name=task_name, response_model=response_model, system=system
224
+ )
225
+
226
+ def __apply_find_replace[R: Result](
227
+ self, *, result: R, find_replace_response: FindReplaceResponse, response_model: type[R]
228
+ ) -> R:
229
+ dumped_result = result.model_dump_json()
230
+ for replacement in find_replace_response.replacements:
231
+ dumped_result = dumped_result.replace(replacement.find, replacement.replace)
232
+
233
+ return response_model.model_validate_json(dumped_result)
234
+
235
+ def __find_last_result[R: Result](self, *, messages: list[LLMMessage], response_model: type[R]) -> R:
236
+ for message in reversed(messages):
237
+ if message["role"] == "assistant":
238
+ try:
239
+ only_part = message["content"][0]
240
+ if only_part["type"] == "text":
241
+ return response_model.model_validate_json(only_part["text"])
242
+ except ValidationError:
243
+ continue
244
+
245
+ raise Honk("No last result found, failed to refine")
goose/_internal/result.py CHANGED
@@ -1,4 +1,4 @@
1
- from pydantic import BaseModel, ConfigDict
1
+ from pydantic import BaseModel, ConfigDict, Field
2
2
 
3
3
 
4
4
  class Result(BaseModel):
@@ -7,3 +7,14 @@ class Result(BaseModel):
7
7
 
class TextResult(Result):
    # A result carrying a single string.
    text: str


class Replacement(BaseModel):
    # A single edit: every occurrence of `find` becomes `replace`.
    # The descriptions end up in the JSON schema generated from this model,
    # so their wording is kept verbatim.
    find: str = Field(
        ...,
        description="Text to find, to be replaced with `replace`",
    )
    replace: str = Field(
        ...,
        description="Text to replace `find` with",
    )


class FindReplaceResponse(BaseModel):
    # Structured reply requested from the model when refining a previous
    # result (used as `response_format` by the agent's refine call).
    replacements: list[Replacement] = Field(
        ...,
        description="List of replacements to make in the previous result to satisfy the user's request",
    )
goose/_internal/task.py CHANGED
@@ -57,11 +57,12 @@ class Task[**P, R: Result]:
57
57
  model=self._refinement_model,
58
58
  task_name=f"ask--{self.name}",
59
59
  system=context.render() if context is not None else None,
60
+ mode="ask",
60
61
  )
61
- node_state.add_answer(answer=answer.text)
62
+ node_state.add_answer(answer=answer)
62
63
  flow_run.upsert_node_state(node_state)
63
64
 
64
- return answer.text
65
+ return answer
65
66
 
66
67
  async def refine(
67
68
  self,
@@ -86,6 +87,7 @@ class Task[**P, R: Result]:
86
87
  task_name=f"refine--{self.name}",
87
88
  system=context.render() if context is not None else None,
88
89
  response_model=self.result_type,
90
+ mode="refine",
89
91
  )
90
92
  node_state.add_result(result=result)
91
93
  flow_run.upsert_node_state(node_state)
@@ -0,0 +1,113 @@
1
+ import json
2
+ from datetime import datetime
3
+ from typing import ClassVar, TypedDict
4
+
5
+ from pydantic import BaseModel, computed_field
6
+
7
+ from ..types.agent import AIModel, LLMMessage
8
+
9
+
10
# Flat, log-friendly summary of an AgentResponse: the return type of
# AgentResponse.minimized_dump(). Message fields hold JSON-encoded strings.
AgentResponseDump = TypedDict(
    "AgentResponseDump",
    {
        "run_id": str,
        "flow_name": str,
        "task_name": str,
        "model": str,
        "system_message": str,
        "input_messages": list[str],
        "output_message": str,
        "input_cost": float,
        "output_cost": float,
        "total_cost": float,
        "input_tokens": int,
        "output_tokens": int,
        "start_time": datetime,
        "end_time": datetime,
        "duration_ms": int,
    },
)
26
+
27
+
28
+ class AgentResponse[R: BaseModel | str](BaseModel):
29
+ INPUT_DOLLARS_PER_MILLION_TOKENS: ClassVar[dict[AIModel, float]] = {
30
+ AIModel.VERTEX_FLASH_8B: 0.30,
31
+ AIModel.VERTEX_FLASH: 0.15,
32
+ AIModel.VERTEX_PRO: 5.00,
33
+ AIModel.GEMINI_FLASH_8B: 0.30,
34
+ AIModel.GEMINI_FLASH: 0.15,
35
+ AIModel.GEMINI_PRO: 5.00,
36
+ }
37
+ OUTPUT_DOLLARS_PER_MILLION_TOKENS: ClassVar[dict[AIModel, float]] = {
38
+ AIModel.VERTEX_FLASH_8B: 0.30,
39
+ AIModel.VERTEX_FLASH: 0.15,
40
+ AIModel.VERTEX_PRO: 5.00,
41
+ AIModel.GEMINI_FLASH_8B: 0.30,
42
+ AIModel.GEMINI_FLASH: 0.15,
43
+ AIModel.GEMINI_PRO: 5.00,
44
+ }
45
+
46
+ response: R
47
+ run_id: str
48
+ flow_name: str
49
+ task_name: str
50
+ model: AIModel
51
+ system: LLMMessage | None = None
52
+ input_messages: list[LLMMessage]
53
+ input_tokens: int
54
+ output_tokens: int
55
+ start_time: datetime
56
+ end_time: datetime
57
+
58
+ @computed_field
59
+ @property
60
+ def duration_ms(self) -> int:
61
+ return int((self.end_time - self.start_time).total_seconds() * 1000)
62
+
63
+ @computed_field
64
+ @property
65
+ def input_cost(self) -> float:
66
+ return self.INPUT_DOLLARS_PER_MILLION_TOKENS[self.model] * self.input_tokens / 1_000_000
67
+
68
+ @computed_field
69
+ @property
70
+ def output_cost(self) -> float:
71
+ return self.OUTPUT_DOLLARS_PER_MILLION_TOKENS[self.model] * self.output_tokens / 1_000_000
72
+
73
+ @computed_field
74
+ @property
75
+ def total_cost(self) -> float:
76
+ return self.input_cost + self.output_cost
77
+
78
+ def minimized_dump(self) -> AgentResponseDump:
79
+ if self.system is None:
80
+ minimized_system_message = ""
81
+ else:
82
+ minimized_system_message = self.system
83
+ for part in minimized_system_message["content"]:
84
+ if part["type"] == "image_url":
85
+ part["image_url"] = "__MEDIA__"
86
+ minimized_system_message = json.dumps(minimized_system_message)
87
+
88
+ minimized_input_messages = [message for message in self.input_messages]
89
+ for message in minimized_input_messages:
90
+ for part in message["content"]:
91
+ if part["type"] == "image_url":
92
+ part["image_url"] = "__MEDIA__"
93
+ minimized_input_messages = [json.dumps(message) for message in minimized_input_messages]
94
+
95
+ output_message = self.response.model_dump_json() if isinstance(self.response, BaseModel) else self.response
96
+
97
+ return {
98
+ "run_id": self.run_id,
99
+ "flow_name": self.flow_name,
100
+ "task_name": self.task_name,
101
+ "model": self.model.value,
102
+ "system_message": minimized_system_message,
103
+ "input_messages": minimized_input_messages,
104
+ "output_message": output_message,
105
+ "input_tokens": self.input_tokens,
106
+ "output_tokens": self.output_tokens,
107
+ "input_cost": self.input_cost,
108
+ "output_cost": self.output_cost,
109
+ "total_cost": self.total_cost,
110
+ "start_time": self.start_time,
111
+ "end_time": self.end_time,
112
+ "duration_ms": self.duration_ms,
113
+ }
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: goose-py
3
- Version: 0.9.16
3
+ Version: 0.10.1
4
4
  Summary: A tool for AI workflows based on human-computer collaboration and structured output.
5
5
  Author-email: Nash Taylor <nash@chelle.ai>, Joshua Cook <joshua@chelle.ai>, Michael Sankur <michael@chelle.ai>
6
6
  Requires-Python: >=3.12
@@ -5,15 +5,16 @@ goose/flow.py,sha256=YsZLBa5I1W27_P6LYGWbtFX8ZYx9vJG3KtENYChHm5E,111
5
5
  goose/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
6
6
  goose/runs.py,sha256=ub-r_gzbUbaIzWXX-jc-dncNxEh6zTfzIkmnDfCSbRI,160
7
7
  goose/task.py,sha256=95rspdxETJoY12IHBl3KjnVIdqQnf1jDKlnGWNWOTvQ,53
8
- goose/_internal/agent.py,sha256=v6v5Sno3Y8jkqjJ1zC6AZL6yFiarTU51AEKQvXFOIGg,5654
8
+ goose/_internal/agent.py,sha256=v_nzQLM_vyKWOc-g1Ke8TwlThkThjraE4HFHNywHMv0,7915
9
9
  goose/_internal/conversation.py,sha256=zvKqLxJSCIIuhD7gjcSFhleYsLabu-ALl9woWFy3mQU,1766
10
10
  goose/_internal/flow.py,sha256=RShMsxgt49g1fZJ3rlwDHtI1j39lZzewx8hZ7DGN5kg,4124
11
- goose/_internal/result.py,sha256=-eZJn-2sPo7rHZ38Sz6IAHXqiJ-Ss39esEoFGimJEBI,155
11
+ goose/_internal/result.py,sha256=vtJMfBxb9skfl8st2tn4hBmEq6qmXiJTme_B5QTgu2M,538
12
12
  goose/_internal/state.py,sha256=U4gM0K4MAlRFTpqenCYHX9TYGuhWVKIfa4yBeZ9Qc9s,7090
13
13
  goose/_internal/store.py,sha256=tWmKfa1-yq1jU6lT3l6kSOmVt2m3H7I1xLMTrxnUDI8,889
14
- goose/_internal/task.py,sha256=w4BW3VDDKGjXb3pqzaxRaWHxLpzDLF2ibdIuJRaT7pc,6211
14
+ goose/_internal/task.py,sha256=mhcmKDTBl993P3HP3PlNvQtl4gMYy4FMYeQ9xrg5aPk,6252
15
15
  goose/_internal/types/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
16
16
  goose/_internal/types/agent.py,sha256=g0KD-aPWZlUGBx72AwQd3LeniFxHATeflZ7191QjFZA,2696
17
- goose_py-0.9.16.dist-info/METADATA,sha256=qLZ6JVeZqh9uKO6SLonDuq5UBciuO3_KpqKvJNno-9I,442
18
- goose_py-0.9.16.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
19
- goose_py-0.9.16.dist-info/RECORD,,
17
+ goose/_internal/types/telemetry.py,sha256=7zeqyDDxf95puirNM6Gr9VFuxoDshXcV1__V0tiMswE,3663
18
+ goose_py-0.10.1.dist-info/METADATA,sha256=L2yzL8ZW09_75wmrK5YSeEZ2H0RkrODL0zWm1nWW-uA,442
19
+ goose_py-0.10.1.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
20
+ goose_py-0.10.1.dist-info/RECORD,,