grasp_agents 0.2.10__py3-none-any.whl → 0.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- grasp_agents/__init__.py +15 -14
- grasp_agents/cloud_llm.py +118 -131
- grasp_agents/comm_processor.py +201 -0
- grasp_agents/generics_utils.py +15 -7
- grasp_agents/llm.py +60 -31
- grasp_agents/llm_agent.py +229 -278
- grasp_agents/llm_agent_memory.py +58 -0
- grasp_agents/llm_policy_executor.py +482 -0
- grasp_agents/memory.py +20 -134
- grasp_agents/message_history.py +140 -0
- grasp_agents/openai/__init__.py +54 -36
- grasp_agents/openai/completion_chunk_converters.py +78 -0
- grasp_agents/openai/completion_converters.py +53 -30
- grasp_agents/openai/content_converters.py +13 -14
- grasp_agents/openai/converters.py +44 -68
- grasp_agents/openai/message_converters.py +58 -72
- grasp_agents/openai/openai_llm.py +101 -42
- grasp_agents/openai/tool_converters.py +24 -19
- grasp_agents/packet.py +24 -0
- grasp_agents/packet_pool.py +91 -0
- grasp_agents/printer.py +29 -15
- grasp_agents/processor.py +194 -0
- grasp_agents/prompt_builder.py +173 -176
- grasp_agents/run_context.py +21 -41
- grasp_agents/typing/completion.py +58 -12
- grasp_agents/typing/completion_chunk.py +173 -0
- grasp_agents/typing/converters.py +8 -12
- grasp_agents/typing/events.py +86 -0
- grasp_agents/typing/io.py +4 -13
- grasp_agents/typing/message.py +12 -50
- grasp_agents/typing/tool.py +52 -26
- grasp_agents/usage_tracker.py +6 -6
- grasp_agents/utils.py +3 -3
- grasp_agents/workflow/looped_workflow.py +132 -0
- grasp_agents/workflow/parallel_processor.py +95 -0
- grasp_agents/workflow/sequential_workflow.py +66 -0
- grasp_agents/workflow/workflow_processor.py +78 -0
- {grasp_agents-0.2.10.dist-info → grasp_agents-0.3.1.dist-info}/METADATA +41 -50
- grasp_agents-0.3.1.dist-info/RECORD +51 -0
- grasp_agents/agent_message.py +0 -27
- grasp_agents/agent_message_pool.py +0 -92
- grasp_agents/base_agent.py +0 -51
- grasp_agents/comm_agent.py +0 -217
- grasp_agents/llm_agent_state.py +0 -79
- grasp_agents/tool_orchestrator.py +0 -203
- grasp_agents/workflow/looped_agent.py +0 -120
- grasp_agents/workflow/sequential_agent.py +0 -63
- grasp_agents/workflow/workflow_agent.py +0 -73
- grasp_agents-0.2.10.dist-info/RECORD +0 -46
- {grasp_agents-0.2.10.dist-info → grasp_agents-0.3.1.dist-info}/WHEEL +0 -0
- {grasp_agents-0.2.10.dist-info → grasp_agents-0.3.1.dist-info}/licenses/LICENSE.md +0 -0
grasp_agents/prompt_builder.py
CHANGED
@@ -1,237 +1,234 @@
|
|
1
|
-
|
2
|
-
from
|
3
|
-
from typing import ClassVar, Generic, Protocol
|
1
|
+
import json
|
2
|
+
from collections.abc import Mapping, Sequence
|
3
|
+
from typing import ClassVar, Generic, Protocol, TypeAlias
|
4
4
|
|
5
5
|
from pydantic import BaseModel, TypeAdapter
|
6
6
|
|
7
7
|
from .generics_utils import AutoInstanceAttributesMixin
|
8
|
-
from .run_context import CtxT,
|
9
|
-
from .typing.content import ImageData
|
10
|
-
from .typing.io import
|
11
|
-
InT,
|
12
|
-
LLMFormattedArgs,
|
13
|
-
LLMFormattedSystemArgs,
|
14
|
-
LLMPrompt,
|
15
|
-
LLMPromptArgs,
|
16
|
-
)
|
8
|
+
from .run_context import CtxT, RunContext
|
9
|
+
from .typing.content import Content, ImageData
|
10
|
+
from .typing.io import InT_contra, LLMPrompt, LLMPromptArgs
|
17
11
|
from .typing.message import UserMessage
|
18
12
|
|
19
13
|
|
20
|
-
class
|
21
|
-
pass
|
22
|
-
|
23
|
-
|
24
|
-
class FormatSystemArgsHandler(Protocol[CtxT]):
|
14
|
+
class MakeSystemPromptHandler(Protocol[CtxT]):
|
25
15
|
def __call__(
|
26
16
|
self,
|
27
|
-
sys_args: LLMPromptArgs,
|
17
|
+
sys_args: LLMPromptArgs | None,
|
28
18
|
*,
|
29
|
-
ctx:
|
30
|
-
) ->
|
19
|
+
ctx: RunContext[CtxT] | None,
|
20
|
+
) -> str: ...
|
31
21
|
|
32
22
|
|
33
|
-
class
|
23
|
+
class MakeInputContentHandler(Protocol[InT_contra, CtxT]):
|
34
24
|
def __call__(
|
35
25
|
self,
|
36
26
|
*,
|
37
|
-
|
38
|
-
|
27
|
+
in_args: InT_contra | None,
|
28
|
+
usr_args: LLMPromptArgs | None,
|
39
29
|
batch_idx: int,
|
40
|
-
ctx:
|
41
|
-
) ->
|
30
|
+
ctx: RunContext[CtxT] | None,
|
31
|
+
) -> Content: ...
|
32
|
+
|
33
|
+
|
34
|
+
PromptArgumentType: TypeAlias = str | bool | int | ImageData
|
42
35
|
|
43
36
|
|
44
|
-
class PromptBuilder(AutoInstanceAttributesMixin, Generic[
|
37
|
+
class PromptBuilder(AutoInstanceAttributesMixin, Generic[InT_contra, CtxT]):
|
45
38
|
_generic_arg_to_instance_attr_map: ClassVar[dict[int, str]] = {0: "_in_type"}
|
46
39
|
|
47
40
|
def __init__(
|
48
41
|
self,
|
49
|
-
|
50
|
-
|
51
|
-
|
52
|
-
sys_args_schema: type[LLMPromptArgs],
|
53
|
-
usr_args_schema: type[LLMPromptArgs],
|
42
|
+
agent_name: str,
|
43
|
+
sys_prompt_template: LLMPrompt | None,
|
44
|
+
in_prompt_template: LLMPrompt | None,
|
45
|
+
sys_args_schema: type[LLMPromptArgs] | None = None,
|
46
|
+
usr_args_schema: type[LLMPromptArgs] | None = None,
|
54
47
|
):
|
55
|
-
self._in_type: type[
|
48
|
+
self._in_type: type[InT_contra]
|
56
49
|
super().__init__()
|
57
50
|
|
58
|
-
self.
|
59
|
-
self.
|
60
|
-
self.
|
51
|
+
self._agent_name = agent_name
|
52
|
+
self.sys_prompt_template = sys_prompt_template
|
53
|
+
self.in_prompt_template = in_prompt_template
|
61
54
|
self.sys_args_schema = sys_args_schema
|
62
55
|
self.usr_args_schema = usr_args_schema
|
63
|
-
self.
|
64
|
-
self.
|
65
|
-
|
66
|
-
|
67
|
-
|
68
|
-
def _format_sys_args(
|
69
|
-
self,
|
70
|
-
sys_args: LLMPromptArgs,
|
71
|
-
ctx: RunContextWrapper[CtxT] | None = None,
|
72
|
-
) -> LLMFormattedSystemArgs:
|
73
|
-
if self.format_sys_args_impl:
|
74
|
-
return self.format_sys_args_impl(sys_args=sys_args, ctx=ctx)
|
56
|
+
self.make_sys_prompt_impl: MakeSystemPromptHandler[CtxT] | None = None
|
57
|
+
self.make_in_content_impl: MakeInputContentHandler[InT_contra, CtxT] | None = (
|
58
|
+
None
|
59
|
+
)
|
75
60
|
|
76
|
-
|
61
|
+
self._in_args_type_adapter: TypeAdapter[InT_contra] = TypeAdapter(self._in_type)
|
77
62
|
|
78
|
-
def
|
79
|
-
self,
|
80
|
-
|
81
|
-
|
82
|
-
|
83
|
-
batch_idx: int = 0,
|
84
|
-
ctx: RunContextWrapper[CtxT] | None = None,
|
85
|
-
) -> LLMFormattedArgs:
|
86
|
-
if self.format_in_args_impl:
|
87
|
-
return self.format_in_args_impl(
|
88
|
-
usr_args=usr_args, in_args=in_args, batch_idx=batch_idx, ctx=ctx
|
89
|
-
)
|
63
|
+
def make_sys_prompt(
|
64
|
+
self, sys_args: LLMPromptArgs | None = None, ctx: RunContext[CtxT] | None = None
|
65
|
+
) -> str | None:
|
66
|
+
if self.sys_prompt_template is None:
|
67
|
+
return None
|
90
68
|
|
91
|
-
|
92
|
-
|
93
|
-
|
94
|
-
|
69
|
+
val_sys_args = sys_args
|
70
|
+
if sys_args is not None:
|
71
|
+
if self.sys_args_schema is not None:
|
72
|
+
val_sys_args = self.sys_args_schema.model_validate(sys_args)
|
73
|
+
else:
|
74
|
+
raise TypeError(
|
75
|
+
"System prompt template is set, but system arguments schema is not "
|
76
|
+
"provided."
|
77
|
+
)
|
95
78
|
|
96
|
-
|
97
|
-
|
79
|
+
if self.make_sys_prompt_impl:
|
80
|
+
return self.make_sys_prompt_impl(sys_args=val_sys_args, ctx=ctx)
|
98
81
|
|
99
|
-
|
100
|
-
|
82
|
+
sys_args_dict = (
|
83
|
+
val_sys_args.model_dump(exclude_unset=True) if val_sys_args else {}
|
84
|
+
)
|
101
85
|
|
102
|
-
return
|
86
|
+
return self.sys_prompt_template.format(**sys_args_dict)
|
103
87
|
|
104
|
-
def
|
88
|
+
def make_in_content(
|
105
89
|
self,
|
106
|
-
sys_args: LLMPromptArgs,
|
107
90
|
*,
|
108
|
-
|
109
|
-
|
110
|
-
|
111
|
-
|
112
|
-
|
113
|
-
|
114
|
-
|
115
|
-
|
116
|
-
|
117
|
-
def _usr_messages_from_text(self, text: str) -> list[UserMessage]:
|
118
|
-
return [UserMessage.from_text(text, model_id=self._agent_id)]
|
119
|
-
|
120
|
-
def _usr_messages_from_content_parts(
|
121
|
-
self, content_parts: Sequence[str | ImageData]
|
122
|
-
) -> Sequence[UserMessage]:
|
123
|
-
return [UserMessage.from_content_parts(content_parts, model_id=self._agent_id)]
|
91
|
+
in_args: InT_contra | None,
|
92
|
+
usr_args: LLMPromptArgs | None,
|
93
|
+
batch_idx: int = 0,
|
94
|
+
ctx: RunContext[CtxT] | None = None,
|
95
|
+
) -> Content:
|
96
|
+
val_in_args, val_usr_args = self._validate_prompt_args(
|
97
|
+
in_args=in_args, usr_args=usr_args
|
98
|
+
)
|
124
99
|
|
125
|
-
|
126
|
-
|
127
|
-
|
128
|
-
return [
|
129
|
-
UserMessage.from_text(
|
130
|
-
self._in_args_type_adapter.dump_json(
|
131
|
-
inp,
|
132
|
-
exclude_unset=True,
|
133
|
-
indent=2,
|
134
|
-
exclude={"selected_recipient_ids"},
|
135
|
-
warnings="error",
|
136
|
-
).decode("utf-8"),
|
137
|
-
model_id=self._agent_id,
|
100
|
+
if self.make_in_content_impl:
|
101
|
+
return self.make_in_content_impl(
|
102
|
+
in_args=val_in_args, usr_args=val_usr_args, batch_idx=batch_idx, ctx=ctx
|
138
103
|
)
|
139
|
-
for inp in in_args_batch
|
140
|
-
]
|
141
104
|
|
142
|
-
|
143
|
-
|
144
|
-
|
145
|
-
usr_args: UserRunArgs | None = None,
|
146
|
-
in_args_batch: Sequence[InT] | None = None,
|
147
|
-
ctx: RunContextWrapper[CtxT] | None = None,
|
148
|
-
) -> Sequence[UserMessage]:
|
149
|
-
usr_args_batch_, in_args_batch_ = self._make_batched(usr_args, in_args_batch)
|
150
|
-
|
151
|
-
val_usr_args_batch_ = [
|
152
|
-
self.usr_args_schema.model_validate(u) for u in usr_args_batch_
|
153
|
-
]
|
154
|
-
val_in_args_batch_ = [
|
155
|
-
self._in_args_type_adapter.validate_python(inp) for inp in in_args_batch_
|
156
|
-
]
|
105
|
+
combined_args = self._combine_args(in_args=val_in_args, usr_args=val_usr_args)
|
106
|
+
if isinstance(combined_args, str):
|
107
|
+
return Content.from_text(combined_args)
|
157
108
|
|
158
|
-
|
159
|
-
|
160
|
-
|
161
|
-
)
|
162
|
-
for i, (val_usr_args, val_in_args) in enumerate(
|
163
|
-
zip(val_usr_args_batch_, val_in_args_batch_, strict=False)
|
109
|
+
if self.in_prompt_template is not None:
|
110
|
+
return Content.from_formatted_prompt(
|
111
|
+
self.in_prompt_template, prompt_args=combined_args
|
164
112
|
)
|
165
|
-
]
|
166
113
|
|
167
|
-
return
|
168
|
-
UserMessage.from_formatted_prompt(
|
169
|
-
prompt_template=in_prompt, prompt_args=in_args
|
170
|
-
)
|
171
|
-
for in_args in formatted_in_args_batch
|
172
|
-
]
|
114
|
+
return Content.from_text(json.dumps(combined_args, indent=2))
|
173
115
|
|
174
116
|
def make_user_messages(
|
175
117
|
self,
|
176
118
|
chat_inputs: LLMPrompt | Sequence[str | ImageData] | None = None,
|
177
|
-
|
178
|
-
|
179
|
-
|
180
|
-
ctx: RunContextWrapper[CtxT] | None = None,
|
119
|
+
in_args_batch: Sequence[InT_contra] | None = None,
|
120
|
+
usr_args: LLMPromptArgs | None = None,
|
121
|
+
ctx: RunContext[CtxT] | None = None,
|
181
122
|
) -> Sequence[UserMessage]:
|
182
|
-
|
183
|
-
if chat_inputs is not None or entry_point:
|
184
|
-
"""
|
185
|
-
* If chat inputs are provided, use them instead of the predefined
|
186
|
-
input prompt template
|
187
|
-
* In a multi-agent system, the predefined input prompt is used to
|
188
|
-
construct agent inputs using the combination of received
|
189
|
-
and user arguments.
|
190
|
-
However, the first agent run (entry point) has no received
|
191
|
-
messages, so we use the chat inputs directly, if provided.
|
192
|
-
"""
|
123
|
+
if chat_inputs:
|
193
124
|
if isinstance(chat_inputs, LLMPrompt):
|
194
125
|
return self._usr_messages_from_text(chat_inputs)
|
126
|
+
return self._usr_messages_from_content_parts(chat_inputs)
|
195
127
|
|
196
|
-
|
197
|
-
|
128
|
+
in_content_batch = [
|
129
|
+
self.make_in_content(
|
130
|
+
in_args=in_args, usr_args=usr_args, batch_idx=i, ctx=ctx
|
131
|
+
)
|
132
|
+
for i, in_args in enumerate(in_args_batch or [None])
|
133
|
+
]
|
134
|
+
return [
|
135
|
+
UserMessage(content=in_content, name=self._agent_name)
|
136
|
+
for in_content in in_content_batch
|
137
|
+
]
|
198
138
|
|
199
|
-
|
200
|
-
|
201
|
-
return self._usr_messages_from_in_args(in_args_batch)
|
139
|
+
def _usr_messages_from_text(self, text: str) -> list[UserMessage]:
|
140
|
+
return [UserMessage.from_text(text, name=self._agent_name)]
|
202
141
|
|
203
|
-
|
204
|
-
|
205
|
-
|
142
|
+
def _usr_messages_from_content_parts(
|
143
|
+
self, content_parts: Sequence[str | ImageData]
|
144
|
+
) -> list[UserMessage]:
|
145
|
+
return [UserMessage.from_content_parts(content_parts, name=self._agent_name)]
|
146
|
+
|
147
|
+
def _validate_prompt_args(
|
148
|
+
self,
|
149
|
+
*,
|
150
|
+
in_args: InT_contra | None,
|
151
|
+
usr_args: LLMPromptArgs | None,
|
152
|
+
) -> tuple[InT_contra | None, LLMPromptArgs | None]:
|
153
|
+
val_usr_args = usr_args
|
154
|
+
if usr_args is not None:
|
155
|
+
if self.in_prompt_template is None:
|
156
|
+
raise TypeError(
|
157
|
+
"Input prompt template is not set, but user arguments are provided."
|
158
|
+
)
|
159
|
+
if self.usr_args_schema is None:
|
160
|
+
raise TypeError(
|
161
|
+
"User arguments schema is not provided, but user arguments are "
|
162
|
+
"given."
|
163
|
+
)
|
164
|
+
val_usr_args = self.usr_args_schema.model_validate(usr_args)
|
165
|
+
|
166
|
+
val_in_args = in_args
|
167
|
+
if in_args is not None:
|
168
|
+
val_in_args = self._in_args_type_adapter.validate_python(in_args)
|
169
|
+
if isinstance(val_in_args, BaseModel):
|
170
|
+
_, has_image = self._format_pydantic_prompt_args(val_in_args)
|
171
|
+
if has_image and self.in_prompt_template is None:
|
172
|
+
raise TypeError(
|
173
|
+
"BaseModel input arguments contain ImageData, but input prompt "
|
174
|
+
"template is not set. Cannot format input arguments."
|
175
|
+
)
|
176
|
+
elif self.in_prompt_template is not None:
|
206
177
|
raise TypeError(
|
207
178
|
"Cannot use the input prompt template with "
|
208
|
-
"non-BaseModel
|
179
|
+
"non-BaseModel input arguments."
|
180
|
+
)
|
181
|
+
|
182
|
+
return val_in_args, val_usr_args
|
183
|
+
|
184
|
+
@staticmethod
|
185
|
+
def _format_pydantic_prompt_args(
|
186
|
+
inp: BaseModel,
|
187
|
+
) -> tuple[dict[str, PromptArgumentType], bool]:
|
188
|
+
formatted_args: dict[str, PromptArgumentType] = {}
|
189
|
+
contains_image_data = False
|
190
|
+
for field in type(inp).model_fields:
|
191
|
+
if field == "selected_recipients":
|
192
|
+
continue
|
193
|
+
|
194
|
+
val = getattr(inp, field)
|
195
|
+
if isinstance(val, (int, str, bool)):
|
196
|
+
formatted_args[field] = val
|
197
|
+
elif isinstance(val, ImageData):
|
198
|
+
formatted_args[field] = val
|
199
|
+
contains_image_data = True
|
200
|
+
elif isinstance(val, BaseModel):
|
201
|
+
formatted_args[field] = val.model_dump_json(indent=2, warnings="error")
|
202
|
+
else:
|
203
|
+
raise TypeError(
|
204
|
+
f"Field '{field}' in prompt arguments must be of type "
|
205
|
+
"int, str, bool, BaseModel, or ImageData."
|
209
206
|
)
|
210
|
-
return self._usr_messages_from_prompt_template(
|
211
|
-
in_prompt=self.in_prompt,
|
212
|
-
usr_args=usr_args,
|
213
|
-
in_args_batch=in_args_batch,
|
214
|
-
ctx=ctx,
|
215
|
-
)
|
216
207
|
|
217
|
-
return
|
208
|
+
return formatted_args, contains_image_data
|
218
209
|
|
219
|
-
def
|
210
|
+
def _combine_args(
|
220
211
|
self,
|
221
|
-
|
222
|
-
|
223
|
-
|
224
|
-
|
225
|
-
|
212
|
+
*,
|
213
|
+
in_args: InT_contra | None,
|
214
|
+
usr_args: LLMPromptArgs | None,
|
215
|
+
) -> Mapping[str, PromptArgumentType] | str:
|
216
|
+
fmt_usr_args, _ = (
|
217
|
+
self._format_pydantic_prompt_args(usr_args) if usr_args else ({}, False)
|
226
218
|
)
|
227
|
-
in_args_batch_ = in_args_batch or [DummySchema()]
|
228
219
|
|
229
|
-
|
230
|
-
|
231
|
-
|
232
|
-
if
|
233
|
-
|
234
|
-
|
235
|
-
|
220
|
+
if in_args is None:
|
221
|
+
return fmt_usr_args
|
222
|
+
|
223
|
+
if isinstance(in_args, BaseModel):
|
224
|
+
fmt_in_args, _ = self._format_pydantic_prompt_args(in_args)
|
225
|
+
return fmt_in_args | fmt_usr_args
|
226
|
+
|
227
|
+
combined_args_str = self._in_args_type_adapter.dump_json(
|
228
|
+
in_args, indent=2, warnings="error"
|
229
|
+
).decode("utf-8")
|
230
|
+
if usr_args is not None:
|
231
|
+
fmt_usr_args_str = usr_args.model_dump_json(indent=2, warnings="error")
|
232
|
+
combined_args_str += "\n" + fmt_usr_args_str
|
236
233
|
|
237
|
-
return
|
234
|
+
return combined_args_str
|
grasp_agents/run_context.py
CHANGED
@@ -1,61 +1,39 @@
|
|
1
|
-
from collections
|
2
|
-
from
|
1
|
+
from collections import defaultdict
|
2
|
+
from collections.abc import Mapping
|
3
|
+
from typing import Any, Generic, TypeVar
|
3
4
|
from uuid import uuid4
|
4
5
|
|
5
6
|
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr
|
6
7
|
|
7
|
-
from .
|
8
|
-
from .typing.content import ImageData
|
9
|
-
from .typing.io import (
|
10
|
-
AgentID,
|
11
|
-
AgentState,
|
12
|
-
InT,
|
13
|
-
LLMPrompt,
|
14
|
-
LLMPromptArgs,
|
15
|
-
OutT,
|
16
|
-
StateT,
|
17
|
-
)
|
18
|
-
from .usage_tracker import UsageTracker
|
8
|
+
from grasp_agents.typing.completion import Completion
|
19
9
|
|
20
|
-
|
21
|
-
|
10
|
+
from .printer import ColoringMode, Printer
|
11
|
+
from .typing.io import LLMPromptArgs, ProcName
|
12
|
+
from .usage_tracker import UsageTracker
|
22
13
|
|
23
14
|
|
24
15
|
class RunArgs(BaseModel):
|
25
|
-
sys:
|
26
|
-
usr:
|
16
|
+
sys: LLMPromptArgs = Field(default_factory=LLMPromptArgs)
|
17
|
+
usr: LLMPromptArgs = Field(default_factory=LLMPromptArgs)
|
27
18
|
|
28
19
|
model_config = ConfigDict(extra="forbid")
|
29
20
|
|
30
21
|
|
31
|
-
class InteractionRecord(BaseModel, Generic[InT, OutT, StateT]):
|
32
|
-
source_id: str
|
33
|
-
recipient_ids: Sequence[AgentID]
|
34
|
-
state: StateT
|
35
|
-
chat_inputs: LLMPrompt | Sequence[str | ImageData] | None = None
|
36
|
-
sys_prompt: LLMPrompt | None = None
|
37
|
-
in_prompt: LLMPrompt | None = None
|
38
|
-
sys_args: SystemRunArgs | None = None
|
39
|
-
usr_args: UserRunArgs | None = None
|
40
|
-
in_args: Sequence[InT] | None = None
|
41
|
-
outputs: Sequence[OutT]
|
42
|
-
|
43
|
-
model_config = ConfigDict(extra="forbid", frozen=True)
|
44
|
-
|
45
|
-
|
46
|
-
InteractionHistory: TypeAlias = list[InteractionRecord[Any, Any, AgentState]]
|
47
|
-
|
48
|
-
|
49
22
|
CtxT = TypeVar("CtxT")
|
50
23
|
|
51
24
|
|
52
|
-
class
|
53
|
-
context: CtxT | None = None
|
25
|
+
class RunContext(BaseModel, Generic[CtxT]):
|
54
26
|
run_id: str = Field(default_factory=lambda: str(uuid4())[:8], frozen=True)
|
55
|
-
|
56
|
-
|
27
|
+
|
28
|
+
state: CtxT | None = None
|
29
|
+
|
30
|
+
run_args: dict[ProcName, RunArgs] = Field(default_factory=dict)
|
31
|
+
completions: Mapping[ProcName, list[Completion]] = Field(
|
32
|
+
default_factory=lambda: defaultdict(list)
|
33
|
+
)
|
57
34
|
|
58
35
|
print_messages: bool = False
|
36
|
+
color_messages_by: ColoringMode = "role"
|
59
37
|
|
60
38
|
_usage_tracker: UsageTracker = PrivateAttr()
|
61
39
|
_printer: Printer = PrivateAttr()
|
@@ -63,7 +41,9 @@ class RunContextWrapper(BaseModel, Generic[CtxT]):
|
|
63
41
|
def model_post_init(self, context: Any) -> None: # noqa: ARG002
|
64
42
|
self._usage_tracker = UsageTracker(source_id=self.run_id)
|
65
43
|
self._printer = Printer(
|
66
|
-
source_id=self.run_id,
|
44
|
+
source_id=self.run_id,
|
45
|
+
print_messages=self.print_messages,
|
46
|
+
color_by=self.color_messages_by,
|
67
47
|
)
|
68
48
|
|
69
49
|
@property
|
@@ -1,14 +1,58 @@
|
|
1
|
-
|
1
|
+
import time
|
2
|
+
from typing import Literal, TypeAlias
|
3
|
+
from uuid import uuid4
|
2
4
|
|
3
|
-
from
|
5
|
+
from openai.types.chat.chat_completion import ChoiceLogprobs as CompletionChoiceLogprobs
|
6
|
+
from pydantic import BaseModel, Field, NonNegativeFloat, NonNegativeInt
|
4
7
|
|
5
8
|
from .message import AssistantMessage
|
6
9
|
|
10
|
+
FinishReason: TypeAlias = Literal[
|
11
|
+
"stop", "length", "tool_calls", "content_filter", "function_call"
|
12
|
+
]
|
13
|
+
|
14
|
+
|
15
|
+
class Usage(BaseModel):
|
16
|
+
input_tokens: NonNegativeInt = 0
|
17
|
+
output_tokens: NonNegativeInt = 0
|
18
|
+
reasoning_tokens: NonNegativeInt | None = None
|
19
|
+
cached_tokens: NonNegativeInt | None = None
|
20
|
+
cost: NonNegativeFloat | None = None
|
21
|
+
|
22
|
+
def __add__(self, add_usage: "Usage") -> "Usage":
|
23
|
+
input_tokens = self.input_tokens + add_usage.input_tokens
|
24
|
+
output_tokens = self.output_tokens + add_usage.output_tokens
|
25
|
+
if self.reasoning_tokens is not None or add_usage.reasoning_tokens is not None:
|
26
|
+
reasoning_tokens = (self.reasoning_tokens or 0) + (
|
27
|
+
add_usage.reasoning_tokens or 0
|
28
|
+
)
|
29
|
+
else:
|
30
|
+
reasoning_tokens = None
|
31
|
+
|
32
|
+
if self.cached_tokens is not None or add_usage.cached_tokens is not None:
|
33
|
+
cached_tokens = (self.cached_tokens or 0) + (add_usage.cached_tokens or 0)
|
34
|
+
else:
|
35
|
+
cached_tokens = None
|
36
|
+
|
37
|
+
cost = (
|
38
|
+
(self.cost or 0.0) + add_usage.cost
|
39
|
+
if (add_usage.cost is not None)
|
40
|
+
else None
|
41
|
+
)
|
42
|
+
return Usage(
|
43
|
+
input_tokens=input_tokens,
|
44
|
+
output_tokens=output_tokens,
|
45
|
+
reasoning_tokens=reasoning_tokens,
|
46
|
+
cached_tokens=cached_tokens,
|
47
|
+
cost=cost,
|
48
|
+
)
|
49
|
+
|
7
50
|
|
8
51
|
class CompletionChoice(BaseModel):
|
9
|
-
# TODO: add fields
|
10
52
|
message: AssistantMessage
|
11
|
-
finish_reason:
|
53
|
+
finish_reason: FinishReason | None
|
54
|
+
index: int
|
55
|
+
logprobs: CompletionChoiceLogprobs | None = None
|
12
56
|
|
13
57
|
|
14
58
|
class CompletionError(BaseModel):
|
@@ -17,14 +61,16 @@ class CompletionError(BaseModel):
|
|
17
61
|
code: int
|
18
62
|
|
19
63
|
|
20
|
-
class Completion(BaseModel
|
21
|
-
|
64
|
+
class Completion(BaseModel):
|
65
|
+
id: str = Field(default_factory=lambda: str(uuid4())[:8])
|
66
|
+
created: int = Field(default_factory=lambda: int(time.time()))
|
67
|
+
model: str
|
68
|
+
name: str | None = None
|
69
|
+
system_fingerprint: str | None = None
|
22
70
|
choices: list[CompletionChoice]
|
23
|
-
|
71
|
+
usage: Usage | None = None
|
24
72
|
error: CompletionError | None = None
|
25
73
|
|
26
|
-
|
27
|
-
|
28
|
-
|
29
|
-
delta: str | None = None
|
30
|
-
model_id: str | None = None
|
74
|
+
@property
|
75
|
+
def messages(self) -> list[AssistantMessage]:
|
76
|
+
return [choice.message for choice in self.choices if choice.message]
|