grasp_agents 0.2.11__py3-none-any.whl → 0.3.1__py3-none-any.whl
This diff compares the contents of two package versions publicly released to a supported registry. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.
- grasp_agents/__init__.py +15 -14
- grasp_agents/cloud_llm.py +118 -131
- grasp_agents/comm_processor.py +201 -0
- grasp_agents/generics_utils.py +15 -7
- grasp_agents/llm.py +60 -31
- grasp_agents/llm_agent.py +229 -273
- grasp_agents/llm_agent_memory.py +58 -0
- grasp_agents/llm_policy_executor.py +482 -0
- grasp_agents/memory.py +20 -134
- grasp_agents/message_history.py +140 -0
- grasp_agents/openai/__init__.py +54 -36
- grasp_agents/openai/completion_chunk_converters.py +78 -0
- grasp_agents/openai/completion_converters.py +53 -30
- grasp_agents/openai/content_converters.py +13 -14
- grasp_agents/openai/converters.py +44 -68
- grasp_agents/openai/message_converters.py +58 -72
- grasp_agents/openai/openai_llm.py +101 -42
- grasp_agents/openai/tool_converters.py +24 -19
- grasp_agents/packet.py +24 -0
- grasp_agents/packet_pool.py +91 -0
- grasp_agents/printer.py +29 -15
- grasp_agents/processor.py +194 -0
- grasp_agents/prompt_builder.py +175 -192
- grasp_agents/run_context.py +20 -37
- grasp_agents/typing/completion.py +58 -12
- grasp_agents/typing/completion_chunk.py +173 -0
- grasp_agents/typing/converters.py +8 -12
- grasp_agents/typing/events.py +86 -0
- grasp_agents/typing/io.py +4 -13
- grasp_agents/typing/message.py +12 -50
- grasp_agents/typing/tool.py +52 -26
- grasp_agents/usage_tracker.py +6 -6
- grasp_agents/utils.py +3 -3
- grasp_agents/workflow/looped_workflow.py +132 -0
- grasp_agents/workflow/parallel_processor.py +95 -0
- grasp_agents/workflow/sequential_workflow.py +66 -0
- grasp_agents/workflow/workflow_processor.py +78 -0
- {grasp_agents-0.2.11.dist-info → grasp_agents-0.3.1.dist-info}/METADATA +41 -50
- grasp_agents-0.3.1.dist-info/RECORD +51 -0
- grasp_agents/agent_message.py +0 -27
- grasp_agents/agent_message_pool.py +0 -92
- grasp_agents/base_agent.py +0 -51
- grasp_agents/comm_agent.py +0 -217
- grasp_agents/llm_agent_state.py +0 -79
- grasp_agents/tool_orchestrator.py +0 -203
- grasp_agents/workflow/looped_agent.py +0 -134
- grasp_agents/workflow/sequential_agent.py +0 -72
- grasp_agents/workflow/workflow_agent.py +0 -88
- grasp_agents-0.2.11.dist-info/RECORD +0 -46
- {grasp_agents-0.2.11.dist-info → grasp_agents-0.3.1.dist-info}/WHEEL +0 -0
- {grasp_agents-0.2.11.dist-info → grasp_agents-0.3.1.dist-info}/licenses/LICENSE.md +0 -0
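
The renames in this list track a shift from agent-centric to processor/packet-centric naming: comm_agent → comm_processor, agent_message(_pool) → packet(_pool), tool_orchestrator → llm_policy_executor, llm_agent_state → llm_agent_memory, and workflow/*_agent → workflow/*_workflow. A rough import-migration sketch follows; apart from RunContextWrapper → RunContext, which is confirmed in the run_context.py diff below, the class names are assumptions inferred from the module names.

# grasp_agents 0.2.11 (modules removed in 0.3.1)
# from grasp_agents.comm_agent import ...            # now comm_processor
# from grasp_agents.agent_message_pool import ...    # now packet_pool
# from grasp_agents.run_context import RunContextWrapper

# grasp_agents 0.3.1
from grasp_agents.comm_processor import CommProcessor  # class name assumed
from grasp_agents.packet_pool import PacketPool        # class name assumed
from grasp_agents.run_context import RunContext        # rename confirmed below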
grasp_agents/prompt_builder.py
CHANGED
@@ -1,251 +1,234 @@
-
- from
- from typing import ClassVar, Generic, Protocol,
+ import json
+ from collections.abc import Mapping, Sequence
+ from typing import ClassVar, Generic, Protocol, TypeAlias
 
  from pydantic import BaseModel, TypeAdapter
 
  from .generics_utils import AutoInstanceAttributesMixin
- from .run_context import CtxT,
- from .typing.content import ImageData
- from .typing.io import (
-     InT,
-     LLMFormattedArgs,
-     LLMFormattedSystemArgs,
-     LLMPrompt,
-     LLMPromptArgs,
- )
+ from .run_context import CtxT, RunContext
+ from .typing.content import Content, ImageData
+ from .typing.io import InT_contra, LLMPrompt, LLMPromptArgs
  from .typing.message import UserMessage
 
 
- class
-     pass
-
-
- class FormatSystemArgsHandler(Protocol[CtxT]):
+ class MakeSystemPromptHandler(Protocol[CtxT]):
      def __call__(
          self,
-         sys_args: LLMPromptArgs,
+         sys_args: LLMPromptArgs | None,
          *,
-         ctx:
-     ) ->
+         ctx: RunContext[CtxT] | None,
+     ) -> str: ...
 
 
- class
+ class MakeInputContentHandler(Protocol[InT_contra, CtxT]):
      def __call__(
          self,
          *,
-
-
+         in_args: InT_contra | None,
+         usr_args: LLMPromptArgs | None,
          batch_idx: int,
-         ctx:
-     ) ->
+         ctx: RunContext[CtxT] | None,
+     ) -> Content: ...
+
 
+ PromptArgumentType: TypeAlias = str | bool | int | ImageData
 
-
+
+ class PromptBuilder(AutoInstanceAttributesMixin, Generic[InT_contra, CtxT]):
      _generic_arg_to_instance_attr_map: ClassVar[dict[int, str]] = {0: "_in_type"}
 
      def __init__(
          self,
-
-
-
-         sys_args_schema: type[LLMPromptArgs],
-         usr_args_schema: type[LLMPromptArgs],
+         agent_name: str,
+         sys_prompt_template: LLMPrompt | None,
+         in_prompt_template: LLMPrompt | None,
+         sys_args_schema: type[LLMPromptArgs] | None = None,
+         usr_args_schema: type[LLMPromptArgs] | None = None,
      ):
-         self._in_type: type[
+         self._in_type: type[InT_contra]
          super().__init__()
 
-         self.
-         self.
-         self.
+         self._agent_name = agent_name
+         self.sys_prompt_template = sys_prompt_template
+         self.in_prompt_template = in_prompt_template
          self.sys_args_schema = sys_args_schema
          self.usr_args_schema = usr_args_schema
-         self.
-         self.
-
-
-
-     def _format_sys_args(
-         self,
-         sys_args: LLMPromptArgs,
-         ctx: RunContextWrapper[CtxT] | None = None,
-     ) -> LLMFormattedSystemArgs:
-         if self.format_sys_args_impl:
-             return self.format_sys_args_impl(sys_args=sys_args, ctx=ctx)
-
-         return sys_args.model_dump(exclude_unset=True)
-
-     def _format_in_args(
-         self,
-         *,
-         usr_args: LLMPromptArgs,
-         in_args: InT,
-         batch_idx: int = 0,
-         ctx: RunContextWrapper[CtxT] | None = None,
-     ) -> LLMFormattedArgs:
-         if self.format_in_args_impl:
-             return self.format_in_args_impl(
-                 usr_args=usr_args, in_args=in_args, batch_idx=batch_idx, ctx=ctx
-             )
-
-         if not isinstance(in_args, BaseModel) and in_args is not None:
-             raise TypeError(
-                 "Cannot apply default formatting to non-BaseModel received arguments."
-             )
-
-         usr_args_ = usr_args
-         in_args_ = DummySchema() if in_args is None else in_args
-
-         usr_args_dump = usr_args_.model_dump(exclude_unset=True)
-         in_args_dump = in_args_.model_dump(exclude={"selected_recipient_ids"})
+         self.make_sys_prompt_impl: MakeSystemPromptHandler[CtxT] | None = None
+         self.make_in_content_impl: MakeInputContentHandler[InT_contra, CtxT] | None = (
+             None
+         )
 
-
+         self._in_args_type_adapter: TypeAdapter[InT_contra] = TypeAdapter(self._in_type)
 
      def make_sys_prompt(
-         self,
-
-
-         ctx: RunContextWrapper[CtxT] | None,
-     ) -> LLMPrompt | None:
-         if self.sys_prompt is None:
+         self, sys_args: LLMPromptArgs | None = None, ctx: RunContext[CtxT] | None = None
+     ) -> str | None:
+         if self.sys_prompt_template is None:
              return None
-         val_sys_args = self.sys_args_schema.model_validate(sys_args)
-         fmt_sys_args = self._format_sys_args(val_sys_args, ctx=ctx)
 
-
+         val_sys_args = sys_args
+         if sys_args is not None:
+             if self.sys_args_schema is not None:
+                 val_sys_args = self.sys_args_schema.model_validate(sys_args)
+             else:
+                 raise TypeError(
+                     "System prompt template is set, but system arguments schema is not "
+                     "provided."
+                 )
 
-
-
+         if self.make_sys_prompt_impl:
+             return self.make_sys_prompt_impl(sys_args=val_sys_args, ctx=ctx)
 
-
-
-
-         return [UserMessage.from_content_parts(content_parts, model_id=self._agent_id)]
+         sys_args_dict = (
+             val_sys_args.model_dump(exclude_unset=True) if val_sys_args else {}
+         )
 
-
-         self, in_args_batch: Sequence[InT]
-     ) -> Sequence[UserMessage]:
-         return [
-             UserMessage.from_text(
-                 self._in_args_type_adapter.dump_json(
-                     inp,
-                     exclude_unset=True,
-                     indent=2,
-                     exclude={"selected_recipient_ids"},
-                     warnings="error",
-                 ).decode("utf-8"),
-                 model_id=self._agent_id,
-             )
-             for inp in in_args_batch
-         ]
+         return self.sys_prompt_template.format(**sys_args_dict)
 
-     def
+     def make_in_content(
          self,
-
-
-
-
-
-
+         *,
+         in_args: InT_contra | None,
+         usr_args: LLMPromptArgs | None,
+         batch_idx: int = 0,
+         ctx: RunContext[CtxT] | None = None,
+     ) -> Content:
+         val_in_args, val_usr_args = self._validate_prompt_args(
+             in_args=in_args, usr_args=usr_args
          )
 
-
-         self.
-
-         val_in_args_batch_ = [
-             self._in_args_type_adapter.validate_python(inp) for inp in in_args_batch_
-         ]
-
-         formatted_in_args_batch = [
-             self._format_in_args(
-                 usr_args=val_usr_args, in_args=val_in_args, batch_idx=i, ctx=ctx
+         if self.make_in_content_impl:
+             return self.make_in_content_impl(
+                 in_args=val_in_args, usr_args=val_usr_args, batch_idx=batch_idx, ctx=ctx
              )
-             for i, (val_usr_args, val_in_args) in enumerate(
-                 zip(val_usr_args_batch_, val_in_args_batch_, strict=False)
-             )
-         ]
 
-
-
-
+         combined_args = self._combine_args(in_args=val_in_args, usr_args=val_usr_args)
+         if isinstance(combined_args, str):
+             return Content.from_text(combined_args)
+
+         if self.in_prompt_template is not None:
+             return Content.from_formatted_prompt(
+                 self.in_prompt_template, prompt_args=combined_args
              )
-
-
+
+         return Content.from_text(json.dumps(combined_args, indent=2))
 
      def make_user_messages(
          self,
          chat_inputs: LLMPrompt | Sequence[str | ImageData] | None = None,
-
-         usr_args: LLMPromptArgs |
-
-         ctx: RunContextWrapper[CtxT] | None = None,
+         in_args_batch: Sequence[InT_contra] | None = None,
+         usr_args: LLMPromptArgs | None = None,
+         ctx: RunContext[CtxT] | None = None,
      ) -> Sequence[UserMessage]:
-
-         if chat_inputs is not None or entry_point:
-             """
-             * If chat inputs are provided, use them instead of the predefined
-               input prompt template
-             * In a multi-agent system, the predefined input prompt is used to
-               construct agent inputs using the combination of received
-               and user arguments.
-               However, the first agent run (entry point) has no received
-               messages, so we use the chat inputs directly, if provided.
-             """
+         if chat_inputs:
              if isinstance(chat_inputs, LLMPrompt):
                  return self._usr_messages_from_text(chat_inputs)
+             return self._usr_messages_from_content_parts(chat_inputs)
 
-
-
+         in_content_batch = [
+             self.make_in_content(
+                 in_args=in_args, usr_args=usr_args, batch_idx=i, ctx=ctx
+             )
+             for i, in_args in enumerate(in_args_batch or [None])
+         ]
+         return [
+             UserMessage(content=in_content, name=self._agent_name)
+             for in_content in in_content_batch
+         ]
 
-
-
-             in_args if (isinstance(in_args, Sequence) or not in_args) else [in_args],
-         )
+     def _usr_messages_from_text(self, text: str) -> list[UserMessage]:
+         return [UserMessage.from_text(text, name=self._agent_name)]
 
-
-
-
-
-         usr_args_batch = cast(
-             "Sequence[LLMPromptArgs] | None",
-             (
-                 usr_args
-                 if (isinstance(usr_args, Sequence) or not usr_args)
-                 else [usr_args]
-             ),
-         )
+     def _usr_messages_from_content_parts(
+         self, content_parts: Sequence[str | ImageData]
+     ) -> list[UserMessage]:
+         return [UserMessage.from_content_parts(content_parts, name=self._agent_name)]
 
-
-
-
+     def _validate_prompt_args(
+         self,
+         *,
+         in_args: InT_contra | None,
+         usr_args: LLMPromptArgs | None,
+     ) -> tuple[InT_contra | None, LLMPromptArgs | None]:
+         val_usr_args = usr_args
+         if usr_args is not None:
+             if self.in_prompt_template is None:
+                 raise TypeError(
+                     "Input prompt template is not set, but user arguments are provided."
+                 )
+             if self.usr_args_schema is None:
+                 raise TypeError(
+                     "User arguments schema is not provided, but user arguments are "
+                     "given."
+                 )
+             val_usr_args = self.usr_args_schema.model_validate(usr_args)
+
+         val_in_args = in_args
+         if in_args is not None:
+             val_in_args = self._in_args_type_adapter.validate_python(in_args)
+             if isinstance(val_in_args, BaseModel):
+                 _, has_image = self._format_pydantic_prompt_args(val_in_args)
+                 if has_image and self.in_prompt_template is None:
+                     raise TypeError(
+                         "BaseModel input arguments contain ImageData, but input prompt "
+                         "template is not set. Cannot format input arguments."
+                     )
+             elif self.in_prompt_template is not None:
                  raise TypeError(
                      "Cannot use the input prompt template with "
-                     "non-BaseModel
+                     "non-BaseModel input arguments."
                  )
-             return self._usr_messages_from_prompt_template(
-                 in_prompt=self.in_prompt,
-                 usr_args_batch=usr_args_batch,
-                 in_args_batch=in_args_batch,
-                 ctx=ctx,
-             )
 
-         return
+         return val_in_args, val_usr_args
+
+     @staticmethod
+     def _format_pydantic_prompt_args(
+         inp: BaseModel,
+     ) -> tuple[dict[str, PromptArgumentType], bool]:
+         formatted_args: dict[str, PromptArgumentType] = {}
+         contains_image_data = False
+         for field in type(inp).model_fields:
+             if field == "selected_recipients":
+                 continue
+
+             val = getattr(inp, field)
+             if isinstance(val, (int, str, bool)):
+                 formatted_args[field] = val
+             elif isinstance(val, ImageData):
+                 formatted_args[field] = val
+                 contains_image_data = True
+             elif isinstance(val, BaseModel):
+                 formatted_args[field] = val.model_dump_json(indent=2, warnings="error")
+             else:
+                 raise TypeError(
+                     f"Field '{field}' in prompt arguments must be of type "
+                     "int, str, bool, BaseModel, or ImageData."
+                 )
+
+         return formatted_args, contains_image_data
 
-     def
+     def _combine_args(
          self,
-
-
-
-
-
-
-
-
-
-
-
-         if
-
-
-
+         *,
+         in_args: InT_contra | None,
+         usr_args: LLMPromptArgs | None,
+     ) -> Mapping[str, PromptArgumentType] | str:
+         fmt_usr_args, _ = (
+             self._format_pydantic_prompt_args(usr_args) if usr_args else ({}, False)
+         )
+
+         if in_args is None:
+             return fmt_usr_args
+
+         if isinstance(in_args, BaseModel):
+             fmt_in_args, _ = self._format_pydantic_prompt_args(in_args)
+             return fmt_in_args | fmt_usr_args
+
+         combined_args_str = self._in_args_type_adapter.dump_json(
+             in_args, indent=2, warnings="error"
+         ).decode("utf-8")
+         if usr_args is not None:
+             fmt_usr_args_str = usr_args.model_dump_json(indent=2, warnings="error")
+             combined_args_str += "\n" + fmt_usr_args_str
+
+         return combined_args_str
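
Taken as a whole, the new PromptBuilder drops the _format_sys_args/_format_in_args hooks in favor of optional make_sys_prompt_impl/make_in_content_impl overrides and moves the prompt templates into the constructor. A minimal usage sketch based only on the new code above; it assumes LLMPromptArgs is a pydantic BaseModel subclass, that LLMPrompt templates are plain str.format strings, and that constructing PromptBuilder directly (rather than through LLMAgent) is supported:

from pydantic import BaseModel

from grasp_agents.prompt_builder import PromptBuilder
from grasp_agents.typing.io import LLMPromptArgs


class SysArgs(LLMPromptArgs):  # fills the {domain} placeholder in the system prompt
    domain: str


class InArgs(BaseModel):  # received input arguments, validated via the TypeAdapter
    question: str


builder = PromptBuilder[InArgs, None](
    agent_name="qa_agent",
    sys_prompt_template="You are an expert in {domain}.",
    in_prompt_template="Answer the question: {question}",
    sys_args_schema=SysArgs,
)

# "You are an expert in biology."
sys_prompt = builder.make_sys_prompt(sys_args=SysArgs(domain="biology"))

# One UserMessage per batch item; content is built from in_prompt_template.
messages = builder.make_user_messages(
    in_args_batch=[InArgs(question="What is a ribosome?")]
)

With no in_prompt_template set, BaseModel inputs would instead be serialized to JSON via json.dumps or TypeAdapter.dump_json, per _combine_args and make_in_content above.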
grasp_agents/run_context.py
CHANGED
@@ -1,58 +1,39 @@
- from collections
- from
+ from collections import defaultdict
+ from collections.abc import Mapping
+ from typing import Any, Generic, TypeVar
  from uuid import uuid4
 
  from pydantic import BaseModel, ConfigDict, Field, PrivateAttr
 
- from .
-
- from .
-
-     AgentState,
-     InT,
-     LLMPrompt,
-     LLMPromptArgs,
-     OutT,
-     StateT,
- )
+ from grasp_agents.typing.completion import Completion
+
+ from .printer import ColoringMode, Printer
+ from .typing.io import LLMPromptArgs, ProcName
  from .usage_tracker import UsageTracker
 
 
  class RunArgs(BaseModel):
      sys: LLMPromptArgs = Field(default_factory=LLMPromptArgs)
-     usr: LLMPromptArgs
+     usr: LLMPromptArgs = Field(default_factory=LLMPromptArgs)
 
      model_config = ConfigDict(extra="forbid")
 
 
- class InteractionRecord(BaseModel, Generic[InT, OutT, StateT]):
-     source_id: str
-     recipient_ids: Sequence[AgentID]
-     state: StateT
-     chat_inputs: LLMPrompt | Sequence[str | ImageData] | None = None
-     sys_prompt: LLMPrompt | None = None
-     in_prompt: LLMPrompt | None = None
-     sys_args: LLMPromptArgs | None = None
-     usr_args: LLMPromptArgs | Sequence[LLMPromptArgs] | None = None
-     in_args: InT | Sequence[InT] | None = None
-     outputs: Sequence[OutT]
-
-     model_config = ConfigDict(extra="forbid", frozen=True)
-
-
- InteractionHistory: TypeAlias = list[InteractionRecord[Any, Any, AgentState]]
-
-
  CtxT = TypeVar("CtxT")
 
 
- class RunContextWrapper(BaseModel, Generic[CtxT]):
-     context: CtxT | None = None
+ class RunContext(BaseModel, Generic[CtxT]):
      run_id: str = Field(default_factory=lambda: str(uuid4())[:8], frozen=True)
-
-
+
+     state: CtxT | None = None
+
+     run_args: dict[ProcName, RunArgs] = Field(default_factory=dict)
+     completions: Mapping[ProcName, list[Completion]] = Field(
+         default_factory=lambda: defaultdict(list)
+     )
 
      print_messages: bool = False
+     color_messages_by: ColoringMode = "role"
 
      _usage_tracker: UsageTracker = PrivateAttr()
      _printer: Printer = PrivateAttr()
@@ -60,7 +41,9 @@ class RunContextWrapper(BaseModel, Generic[CtxT]):
 
      def model_post_init(self, context: Any) -> None:  # noqa: ARG002
          self._usage_tracker = UsageTracker(source_id=self.run_id)
          self._printer = Printer(
-             source_id=self.run_id,
+             source_id=self.run_id,
+             print_messages=self.print_messages,
+             color_by=self.color_messages_by,
          )
 
      @property
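
RunContextWrapper is renamed to RunContext and flattened: the user-supplied context object moves from context to state, the InteractionRecord history is gone, and the model now accumulates per-processor run arguments and completions keyed by ProcName, alongside the new printer settings. A construction sketch consistent with the fields above; how the context is threaded through an agent run is not shown in this diff, and "my_processor" is a placeholder name:

from pydantic import BaseModel

from grasp_agents.run_context import RunArgs, RunContext


class AppState(BaseModel):  # arbitrary user-defined state bound to CtxT
    notes: list[str] = []


ctx = RunContext[AppState](
    state=AppState(),
    print_messages=True,
    color_messages_by="role",  # ColoringMode; "role" is the default above
)

# Per-processor prompt arguments, keyed by processor name (ProcName):
ctx.run_args["my_processor"] = RunArgs()

# After runs, completions are grouped the same way:
# ctx.completions["my_processor"] -> list[Completion]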
grasp_agents/typing/completion.py
CHANGED
@@ -1,14 +1,58 @@
-
+ import time
+ from typing import Literal, TypeAlias
+ from uuid import uuid4
 
- from
+ from openai.types.chat.chat_completion import ChoiceLogprobs as CompletionChoiceLogprobs
+ from pydantic import BaseModel, Field, NonNegativeFloat, NonNegativeInt
 
  from .message import AssistantMessage
 
+ FinishReason: TypeAlias = Literal[
+     "stop", "length", "tool_calls", "content_filter", "function_call"
+ ]
+
+
+ class Usage(BaseModel):
+     input_tokens: NonNegativeInt = 0
+     output_tokens: NonNegativeInt = 0
+     reasoning_tokens: NonNegativeInt | None = None
+     cached_tokens: NonNegativeInt | None = None
+     cost: NonNegativeFloat | None = None
+
+     def __add__(self, add_usage: "Usage") -> "Usage":
+         input_tokens = self.input_tokens + add_usage.input_tokens
+         output_tokens = self.output_tokens + add_usage.output_tokens
+         if self.reasoning_tokens is not None or add_usage.reasoning_tokens is not None:
+             reasoning_tokens = (self.reasoning_tokens or 0) + (
+                 add_usage.reasoning_tokens or 0
+             )
+         else:
+             reasoning_tokens = None
+
+         if self.cached_tokens is not None or add_usage.cached_tokens is not None:
+             cached_tokens = (self.cached_tokens or 0) + (add_usage.cached_tokens or 0)
+         else:
+             cached_tokens = None
+
+         cost = (
+             (self.cost or 0.0) + add_usage.cost
+             if (add_usage.cost is not None)
+             else None
+         )
+         return Usage(
+             input_tokens=input_tokens,
+             output_tokens=output_tokens,
+             reasoning_tokens=reasoning_tokens,
+             cached_tokens=cached_tokens,
+             cost=cost,
+         )
+
 
  class CompletionChoice(BaseModel):
-     # TODO: add fields
      message: AssistantMessage
-     finish_reason:
+     finish_reason: FinishReason | None
+     index: int
+     logprobs: CompletionChoiceLogprobs | None = None
 
 
  class CompletionError(BaseModel):
@@ -17,14 +61,16 @@ class CompletionError(BaseModel):
      code: int
 
 
- class Completion(BaseModel
-
+ class Completion(BaseModel):
+     id: str = Field(default_factory=lambda: str(uuid4())[:8])
+     created: int = Field(default_factory=lambda: int(time.time()))
+     model: str
+     name: str | None = None
+     system_fingerprint: str | None = None
      choices: list[CompletionChoice]
-
+     usage: Usage | None = None
      error: CompletionError | None = None
 
-
-
-
-     delta: str | None = None
-     model_id: str | None = None
+     @property
+     def messages(self) -> list[AssistantMessage]:
+         return [choice.message for choice in self.choices if choice.message]
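
The new Usage model makes token accounting additive. Note the asymmetry in __add__: the optional counters (reasoning_tokens, cached_tokens) survive whenever either side has them, but the summed cost becomes None whenever the right-hand operand's cost is None. A small illustration with arbitrary values:

from grasp_agents.typing.completion import Usage

a = Usage(input_tokens=100, output_tokens=20, cached_tokens=10, cost=0.002)
b = Usage(input_tokens=50, output_tokens=5)

total = a + b
assert total.input_tokens == 150 and total.output_tokens == 25
assert total.cached_tokens == 10       # None on one side is treated as 0
assert total.reasoning_tokens is None  # None on both sides stays None
assert total.cost is None              # b.cost is None, so the sum is dropped

The Completion.messages property then gives flat access to each choice's AssistantMessage without iterating over choices manually.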