grasp_agents 0.2.10__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51)
  1. grasp_agents/__init__.py +15 -14
  2. grasp_agents/cloud_llm.py +118 -131
  3. grasp_agents/comm_processor.py +201 -0
  4. grasp_agents/generics_utils.py +15 -7
  5. grasp_agents/llm.py +60 -31
  6. grasp_agents/llm_agent.py +229 -278
  7. grasp_agents/llm_agent_memory.py +58 -0
  8. grasp_agents/llm_policy_executor.py +482 -0
  9. grasp_agents/memory.py +20 -134
  10. grasp_agents/message_history.py +140 -0
  11. grasp_agents/openai/__init__.py +54 -36
  12. grasp_agents/openai/completion_chunk_converters.py +78 -0
  13. grasp_agents/openai/completion_converters.py +53 -30
  14. grasp_agents/openai/content_converters.py +13 -14
  15. grasp_agents/openai/converters.py +44 -68
  16. grasp_agents/openai/message_converters.py +58 -72
  17. grasp_agents/openai/openai_llm.py +101 -42
  18. grasp_agents/openai/tool_converters.py +24 -19
  19. grasp_agents/packet.py +24 -0
  20. grasp_agents/packet_pool.py +91 -0
  21. grasp_agents/printer.py +29 -15
  22. grasp_agents/processor.py +194 -0
  23. grasp_agents/prompt_builder.py +173 -176
  24. grasp_agents/run_context.py +21 -41
  25. grasp_agents/typing/completion.py +58 -12
  26. grasp_agents/typing/completion_chunk.py +173 -0
  27. grasp_agents/typing/converters.py +8 -12
  28. grasp_agents/typing/events.py +86 -0
  29. grasp_agents/typing/io.py +4 -13
  30. grasp_agents/typing/message.py +12 -50
  31. grasp_agents/typing/tool.py +52 -26
  32. grasp_agents/usage_tracker.py +6 -6
  33. grasp_agents/utils.py +3 -3
  34. grasp_agents/workflow/looped_workflow.py +132 -0
  35. grasp_agents/workflow/parallel_processor.py +95 -0
  36. grasp_agents/workflow/sequential_workflow.py +66 -0
  37. grasp_agents/workflow/workflow_processor.py +78 -0
  38. {grasp_agents-0.2.10.dist-info → grasp_agents-0.3.1.dist-info}/METADATA +41 -50
  39. grasp_agents-0.3.1.dist-info/RECORD +51 -0
  40. grasp_agents/agent_message.py +0 -27
  41. grasp_agents/agent_message_pool.py +0 -92
  42. grasp_agents/base_agent.py +0 -51
  43. grasp_agents/comm_agent.py +0 -217
  44. grasp_agents/llm_agent_state.py +0 -79
  45. grasp_agents/tool_orchestrator.py +0 -203
  46. grasp_agents/workflow/looped_agent.py +0 -120
  47. grasp_agents/workflow/sequential_agent.py +0 -63
  48. grasp_agents/workflow/workflow_agent.py +0 -73
  49. grasp_agents-0.2.10.dist-info/RECORD +0 -46
  50. {grasp_agents-0.2.10.dist-info → grasp_agents-0.3.1.dist-info}/WHEEL +0 -0
  51. {grasp_agents-0.2.10.dist-info → grasp_agents-0.3.1.dist-info}/licenses/LICENSE.md +0 -0
--- a/grasp_agents/prompt_builder.py
+++ b/grasp_agents/prompt_builder.py
@@ -1,237 +1,234 @@
-from collections.abc import Sequence
-from copy import deepcopy
-from typing import ClassVar, Generic, Protocol
+import json
+from collections.abc import Mapping, Sequence
+from typing import ClassVar, Generic, Protocol, TypeAlias
 
 from pydantic import BaseModel, TypeAdapter
 
 from .generics_utils import AutoInstanceAttributesMixin
-from .run_context import CtxT, RunContextWrapper, UserRunArgs
-from .typing.content import ImageData
-from .typing.io import (
-    InT,
-    LLMFormattedArgs,
-    LLMFormattedSystemArgs,
-    LLMPrompt,
-    LLMPromptArgs,
-)
+from .run_context import CtxT, RunContext
+from .typing.content import Content, ImageData
+from .typing.io import InT_contra, LLMPrompt, LLMPromptArgs
 from .typing.message import UserMessage
 
 
-class DummySchema(BaseModel):
-    pass
-
-
-class FormatSystemArgsHandler(Protocol[CtxT]):
+class MakeSystemPromptHandler(Protocol[CtxT]):
     def __call__(
         self,
-        sys_args: LLMPromptArgs,
+        sys_args: LLMPromptArgs | None,
         *,
-        ctx: RunContextWrapper[CtxT] | None,
-    ) -> LLMFormattedSystemArgs: ...
+        ctx: RunContext[CtxT] | None,
+    ) -> str: ...
 
 
-class FormatInputArgsHandler(Protocol[InT, CtxT]):
+class MakeInputContentHandler(Protocol[InT_contra, CtxT]):
     def __call__(
         self,
         *,
-        usr_args: LLMPromptArgs,
-        in_args: InT,
+        in_args: InT_contra | None,
+        usr_args: LLMPromptArgs | None,
         batch_idx: int,
-        ctx: RunContextWrapper[CtxT] | None,
-    ) -> LLMFormattedArgs: ...
+        ctx: RunContext[CtxT] | None,
+    ) -> Content: ...
+
+
+PromptArgumentType: TypeAlias = str | bool | int | ImageData
 
 
-class PromptBuilder(AutoInstanceAttributesMixin, Generic[InT, CtxT]):
+class PromptBuilder(AutoInstanceAttributesMixin, Generic[InT_contra, CtxT]):
     _generic_arg_to_instance_attr_map: ClassVar[dict[int, str]] = {0: "_in_type"}
 
     def __init__(
         self,
-        agent_id: str,
-        sys_prompt: LLMPrompt | None,
-        in_prompt: LLMPrompt | None,
-        sys_args_schema: type[LLMPromptArgs],
-        usr_args_schema: type[LLMPromptArgs],
+        agent_name: str,
+        sys_prompt_template: LLMPrompt | None,
+        in_prompt_template: LLMPrompt | None,
+        sys_args_schema: type[LLMPromptArgs] | None = None,
+        usr_args_schema: type[LLMPromptArgs] | None = None,
     ):
-        self._in_type: type[InT]
+        self._in_type: type[InT_contra]
         super().__init__()
 
-        self._agent_id = agent_id
-        self.sys_prompt = sys_prompt
-        self.in_prompt = in_prompt
+        self._agent_name = agent_name
+        self.sys_prompt_template = sys_prompt_template
+        self.in_prompt_template = in_prompt_template
         self.sys_args_schema = sys_args_schema
         self.usr_args_schema = usr_args_schema
-        self.format_sys_args_impl: FormatSystemArgsHandler[CtxT] | None = None
-        self.format_in_args_impl: FormatInputArgsHandler[InT, CtxT] | None = None
-
-        self._in_args_type_adapter: TypeAdapter[InT] = TypeAdapter(self._in_type)
-
-    def _format_sys_args(
-        self,
-        sys_args: LLMPromptArgs,
-        ctx: RunContextWrapper[CtxT] | None = None,
-    ) -> LLMFormattedSystemArgs:
-        if self.format_sys_args_impl:
-            return self.format_sys_args_impl(sys_args=sys_args, ctx=ctx)
+        self.make_sys_prompt_impl: MakeSystemPromptHandler[CtxT] | None = None
+        self.make_in_content_impl: MakeInputContentHandler[InT_contra, CtxT] | None = (
+            None
+        )
 
-        return sys_args.model_dump(exclude_unset=True)
+        self._in_args_type_adapter: TypeAdapter[InT_contra] = TypeAdapter(self._in_type)
 
-    def _format_in_args(
-        self,
-        *,
-        usr_args: LLMPromptArgs,
-        in_args: InT,
-        batch_idx: int = 0,
-        ctx: RunContextWrapper[CtxT] | None = None,
-    ) -> LLMFormattedArgs:
-        if self.format_in_args_impl:
-            return self.format_in_args_impl(
-                usr_args=usr_args, in_args=in_args, batch_idx=batch_idx, ctx=ctx
-            )
+    def make_sys_prompt(
+        self, sys_args: LLMPromptArgs | None = None, ctx: RunContext[CtxT] | None = None
+    ) -> str | None:
+        if self.sys_prompt_template is None:
+            return None
 
-        if not isinstance(in_args, BaseModel) and in_args is not None:
-            raise TypeError(
-                "Cannot apply default formatting to non-BaseModel received arguments."
-            )
+        val_sys_args = sys_args
+        if sys_args is not None:
+            if self.sys_args_schema is not None:
+                val_sys_args = self.sys_args_schema.model_validate(sys_args)
+            else:
+                raise TypeError(
+                    "System prompt template is set, but system arguments schema is not "
+                    "provided."
+                )
 
-        usr_args_ = usr_args
-        in_args_ = DummySchema() if in_args is None else in_args
+        if self.make_sys_prompt_impl:
+            return self.make_sys_prompt_impl(sys_args=val_sys_args, ctx=ctx)
 
-        usr_args_dump = usr_args_.model_dump(exclude_unset=True)
-        in_args_dump = in_args_.model_dump(exclude={"selected_recipient_ids"})
+        sys_args_dict = (
+            val_sys_args.model_dump(exclude_unset=True) if val_sys_args else {}
+        )
 
-        return usr_args_dump | in_args_dump
+        return self.sys_prompt_template.format(**sys_args_dict)
 
-    def make_sys_prompt(
+    def make_in_content(
         self,
-        sys_args: LLMPromptArgs,
         *,
-        ctx: RunContextWrapper[CtxT] | None,
-    ) -> LLMPrompt | None:
-        if self.sys_prompt is None:
-            return None
-        val_sys_args = self.sys_args_schema.model_validate(sys_args)
-        fmt_sys_args = self._format_sys_args(val_sys_args, ctx=ctx)
-
-        return self.sys_prompt.format(**fmt_sys_args)
-
-    def _usr_messages_from_text(self, text: str) -> list[UserMessage]:
-        return [UserMessage.from_text(text, model_id=self._agent_id)]
-
-    def _usr_messages_from_content_parts(
-        self, content_parts: Sequence[str | ImageData]
-    ) -> Sequence[UserMessage]:
-        return [UserMessage.from_content_parts(content_parts, model_id=self._agent_id)]
+        in_args: InT_contra | None,
+        usr_args: LLMPromptArgs | None,
+        batch_idx: int = 0,
+        ctx: RunContext[CtxT] | None = None,
+    ) -> Content:
+        val_in_args, val_usr_args = self._validate_prompt_args(
+            in_args=in_args, usr_args=usr_args
+        )
 
-    def _usr_messages_from_in_args(
-        self, in_args_batch: Sequence[InT]
-    ) -> Sequence[UserMessage]:
-        return [
-            UserMessage.from_text(
-                self._in_args_type_adapter.dump_json(
-                    inp,
-                    exclude_unset=True,
-                    indent=2,
-                    exclude={"selected_recipient_ids"},
-                    warnings="error",
-                ).decode("utf-8"),
-                model_id=self._agent_id,
+        if self.make_in_content_impl:
+            return self.make_in_content_impl(
+                in_args=val_in_args, usr_args=val_usr_args, batch_idx=batch_idx, ctx=ctx
             )
-            for inp in in_args_batch
-        ]
 
-    def _usr_messages_from_prompt_template(
-        self,
-        in_prompt: LLMPrompt,
-        usr_args: UserRunArgs | None = None,
-        in_args_batch: Sequence[InT] | None = None,
-        ctx: RunContextWrapper[CtxT] | None = None,
-    ) -> Sequence[UserMessage]:
-        usr_args_batch_, in_args_batch_ = self._make_batched(usr_args, in_args_batch)
-
-        val_usr_args_batch_ = [
-            self.usr_args_schema.model_validate(u) for u in usr_args_batch_
-        ]
-        val_in_args_batch_ = [
-            self._in_args_type_adapter.validate_python(inp) for inp in in_args_batch_
-        ]
+        combined_args = self._combine_args(in_args=val_in_args, usr_args=val_usr_args)
+        if isinstance(combined_args, str):
+            return Content.from_text(combined_args)
 
-        formatted_in_args_batch = [
-            self._format_in_args(
-                usr_args=val_usr_args, in_args=val_in_args, batch_idx=i, ctx=ctx
-            )
-            for i, (val_usr_args, val_in_args) in enumerate(
-                zip(val_usr_args_batch_, val_in_args_batch_, strict=False)
+        if self.in_prompt_template is not None:
+            return Content.from_formatted_prompt(
+                self.in_prompt_template, prompt_args=combined_args
             )
-        ]
 
-        return [
-            UserMessage.from_formatted_prompt(
-                prompt_template=in_prompt, prompt_args=in_args
-            )
-            for in_args in formatted_in_args_batch
-        ]
+        return Content.from_text(json.dumps(combined_args, indent=2))
 
     def make_user_messages(
         self,
         chat_inputs: LLMPrompt | Sequence[str | ImageData] | None = None,
-        usr_args: UserRunArgs | None = None,
-        in_args_batch: Sequence[InT] | None = None,
-        entry_point: bool = False,
-        ctx: RunContextWrapper[CtxT] | None = None,
+        in_args_batch: Sequence[InT_contra] | None = None,
+        usr_args: LLMPromptArgs | None = None,
+        ctx: RunContext[CtxT] | None = None,
     ) -> Sequence[UserMessage]:
-        # 1) Direct user input (e.g. chat input)
-        if chat_inputs is not None or entry_point:
-            """
-            * If chat inputs are provided, use them instead of the predefined
-              input prompt template
-            * In a multi-agent system, the predefined input prompt is used to
-              construct agent inputs using the combination of received
-              and user arguments.
-              However, the first agent run (entry point) has no received
-              messages, so we use the chat inputs directly, if provided.
-            """
+        if chat_inputs:
             if isinstance(chat_inputs, LLMPrompt):
                 return self._usr_messages_from_text(chat_inputs)
+            return self._usr_messages_from_content_parts(chat_inputs)
 
-            if isinstance(chat_inputs, Sequence) and chat_inputs:
-                return self._usr_messages_from_content_parts(chat_inputs)
+        in_content_batch = [
+            self.make_in_content(
+                in_args=in_args, usr_args=usr_args, batch_idx=i, ctx=ctx
+            )
+            for i, in_args in enumerate(in_args_batch or [None])
+        ]
+        return [
+            UserMessage(content=in_content, name=self._agent_name)
+            for in_content in in_content_batch
+        ]
 
-        # 2) No input prompt template + received args → raw JSON messages
-        if self.in_prompt is None and in_args_batch:
-            return self._usr_messages_from_in_args(in_args_batch)
+    def _usr_messages_from_text(self, text: str) -> list[UserMessage]:
+        return [UserMessage.from_text(text, name=self._agent_name)]
 
-        # 3) Input prompt template + any args → batch & format
-        if self.in_prompt is not None:
-            if in_args_batch and not isinstance(in_args_batch[0], BaseModel):
+    def _usr_messages_from_content_parts(
+        self, content_parts: Sequence[str | ImageData]
+    ) -> list[UserMessage]:
+        return [UserMessage.from_content_parts(content_parts, name=self._agent_name)]
+
+    def _validate_prompt_args(
+        self,
+        *,
+        in_args: InT_contra | None,
+        usr_args: LLMPromptArgs | None,
+    ) -> tuple[InT_contra | None, LLMPromptArgs | None]:
+        val_usr_args = usr_args
+        if usr_args is not None:
+            if self.in_prompt_template is None:
+                raise TypeError(
+                    "Input prompt template is not set, but user arguments are provided."
+                )
+            if self.usr_args_schema is None:
+                raise TypeError(
+                    "User arguments schema is not provided, but user arguments are "
+                    "given."
+                )
+            val_usr_args = self.usr_args_schema.model_validate(usr_args)
+
+        val_in_args = in_args
+        if in_args is not None:
+            val_in_args = self._in_args_type_adapter.validate_python(in_args)
+            if isinstance(val_in_args, BaseModel):
+                _, has_image = self._format_pydantic_prompt_args(val_in_args)
+                if has_image and self.in_prompt_template is None:
+                    raise TypeError(
+                        "BaseModel input arguments contain ImageData, but input prompt "
+                        "template is not set. Cannot format input arguments."
+                    )
+            elif self.in_prompt_template is not None:
                 raise TypeError(
                     "Cannot use the input prompt template with "
-                    "non-BaseModel received arguments."
+                    "non-BaseModel input arguments."
+                )
+
+        return val_in_args, val_usr_args
+
+    @staticmethod
+    def _format_pydantic_prompt_args(
+        inp: BaseModel,
+    ) -> tuple[dict[str, PromptArgumentType], bool]:
+        formatted_args: dict[str, PromptArgumentType] = {}
+        contains_image_data = False
+        for field in type(inp).model_fields:
+            if field == "selected_recipients":
+                continue
+
+            val = getattr(inp, field)
+            if isinstance(val, (int, str, bool)):
+                formatted_args[field] = val
+            elif isinstance(val, ImageData):
+                formatted_args[field] = val
+                contains_image_data = True
+            elif isinstance(val, BaseModel):
+                formatted_args[field] = val.model_dump_json(indent=2, warnings="error")
+            else:
+                raise TypeError(
+                    f"Field '{field}' in prompt arguments must be of type "
+                    "int, str, bool, BaseModel, or ImageData."
                 )
-            return self._usr_messages_from_prompt_template(
-                in_prompt=self.in_prompt,
-                usr_args=usr_args,
-                in_args_batch=in_args_batch,
-                ctx=ctx,
-            )
 
-        return []
+        return formatted_args, contains_image_data
 
-    def _make_batched(
+    def _combine_args(
         self,
-        usr_args: UserRunArgs | None = None,
-        in_args_batch: Sequence[InT] | None = None,
-    ) -> tuple[Sequence[LLMPromptArgs | DummySchema], Sequence[InT | DummySchema]]:
-        usr_args_batch_ = (
-            usr_args if isinstance(usr_args, list) else [usr_args or DummySchema()]
+        *,
+        in_args: InT_contra | None,
+        usr_args: LLMPromptArgs | None,
+    ) -> Mapping[str, PromptArgumentType] | str:
+        fmt_usr_args, _ = (
+            self._format_pydantic_prompt_args(usr_args) if usr_args else ({}, False)
         )
-        in_args_batch_ = in_args_batch or [DummySchema()]
 
-        # Broadcast singleton → match lengths
-        if len(usr_args_batch_) == 1 and len(in_args_batch_) > 1:
-            usr_args_batch_ = [deepcopy(usr_args_batch_[0]) for _ in in_args_batch_]
-        if len(in_args_batch_) == 1 and len(usr_args_batch_) > 1:
-            in_args_batch_ = [deepcopy(in_args_batch_[0]) for _ in usr_args_batch_]
-        if len(usr_args_batch_) != len(in_args_batch_):
-            raise ValueError("User args and received args must have the same length")
+        if in_args is None:
+            return fmt_usr_args
+
+        if isinstance(in_args, BaseModel):
+            fmt_in_args, _ = self._format_pydantic_prompt_args(in_args)
+            return fmt_in_args | fmt_usr_args
+
+        combined_args_str = self._in_args_type_adapter.dump_json(
+            in_args, indent=2, warnings="error"
+        ).decode("utf-8")
+        if usr_args is not None:
+            fmt_usr_args_str = usr_args.model_dump_json(indent=2, warnings="error")
+            combined_args_str += "\n" + fmt_usr_args_str
 
-        return usr_args_batch_, in_args_batch_
+        return combined_args_str
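Note: the 0.3.1 PromptBuilder drops batched user-args broadcasting (_make_batched) in favor of validating one in_args/usr_args pair per input and emitting one UserMessage per entry of in_args_batch; the entry_point flag is gone. A minimal usage sketch against the new signatures above (the StyleArgs and WeatherQuery models and the template strings are hypothetical; it assumes LLMPrompt is a plain string alias, as the .format(**...) call suggests, LLMPromptArgs is a pydantic BaseModel, and generic subscription binds the input type via AutoInstanceAttributesMixin as in 0.2.x):

from pydantic import BaseModel

from grasp_agents.prompt_builder import PromptBuilder
from grasp_agents.typing.io import LLMPromptArgs


class StyleArgs(LLMPromptArgs):  # hypothetical system-args schema
    style: str


class WeatherQuery(BaseModel):  # hypothetical input-args model
    city: str
    units: str


# Schemas are now optional, but make_sys_prompt raises TypeError if sys_args
# are passed while sys_args_schema is None.
builder = PromptBuilder[WeatherQuery, None](
    agent_name="weather_agent",
    sys_prompt_template="You report the weather in a {style} style.",
    in_prompt_template="Forecast for {city} ({units}).",
    sys_args_schema=StyleArgs,
)

sys_prompt = builder.make_sys_prompt(sys_args=StyleArgs(style="terse"))

# One UserMessage per batch entry; without a batch, a single None input is used.
messages = builder.make_user_messages(
    in_args_batch=[WeatherQuery(city="Oslo", units="metric")]
)

Custom formatting now hooks in via make_sys_prompt_impl / make_in_content_impl, which return a finished prompt string or Content directly, instead of the dicts of formatted args that the 0.2.x format_sys_args_impl / format_in_args_impl handlers returned.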
--- a/grasp_agents/run_context.py
+++ b/grasp_agents/run_context.py
@@ -1,61 +1,39 @@
-from collections.abc import Sequence
-from typing import Any, Generic, TypeAlias, TypeVar
+from collections import defaultdict
+from collections.abc import Mapping
+from typing import Any, Generic, TypeVar
 from uuid import uuid4
 
 from pydantic import BaseModel, ConfigDict, Field, PrivateAttr
 
-from .printer import Printer
-from .typing.content import ImageData
-from .typing.io import (
-    AgentID,
-    AgentState,
-    InT,
-    LLMPrompt,
-    LLMPromptArgs,
-    OutT,
-    StateT,
-)
-from .usage_tracker import UsageTracker
+from grasp_agents.typing.completion import Completion
 
-SystemRunArgs: TypeAlias = LLMPromptArgs
-UserRunArgs: TypeAlias = LLMPromptArgs | list[LLMPromptArgs]
+from .printer import ColoringMode, Printer
+from .typing.io import LLMPromptArgs, ProcName
+from .usage_tracker import UsageTracker
 
 
 class RunArgs(BaseModel):
-    sys: SystemRunArgs = Field(default_factory=LLMPromptArgs)
-    usr: UserRunArgs = Field(default_factory=LLMPromptArgs)
+    sys: LLMPromptArgs = Field(default_factory=LLMPromptArgs)
+    usr: LLMPromptArgs = Field(default_factory=LLMPromptArgs)
 
     model_config = ConfigDict(extra="forbid")
 
 
-class InteractionRecord(BaseModel, Generic[InT, OutT, StateT]):
-    source_id: str
-    recipient_ids: Sequence[AgentID]
-    state: StateT
-    chat_inputs: LLMPrompt | Sequence[str | ImageData] | None = None
-    sys_prompt: LLMPrompt | None = None
-    in_prompt: LLMPrompt | None = None
-    sys_args: SystemRunArgs | None = None
-    usr_args: UserRunArgs | None = None
-    in_args: Sequence[InT] | None = None
-    outputs: Sequence[OutT]
-
-    model_config = ConfigDict(extra="forbid", frozen=True)
-
-
-InteractionHistory: TypeAlias = list[InteractionRecord[Any, Any, AgentState]]
-
-
 CtxT = TypeVar("CtxT")
 
 
-class RunContextWrapper(BaseModel, Generic[CtxT]):
-    context: CtxT | None = None
+class RunContext(BaseModel, Generic[CtxT]):
     run_id: str = Field(default_factory=lambda: str(uuid4())[:8], frozen=True)
-    run_args: dict[AgentID, RunArgs] = Field(default_factory=dict)
-    interaction_history: InteractionHistory = Field(default_factory=list)  # type: ignore[valid-type]
+
+    state: CtxT | None = None
+
+    run_args: dict[ProcName, RunArgs] = Field(default_factory=dict)
+    completions: Mapping[ProcName, list[Completion]] = Field(
+        default_factory=lambda: defaultdict(list)
+    )
 
     print_messages: bool = False
+    color_messages_by: ColoringMode = "role"
 
     _usage_tracker: UsageTracker = PrivateAttr()
     _printer: Printer = PrivateAttr()
@@ -63,7 +41,9 @@ class RunContextWrapper(BaseModel, Generic[CtxT]):
     def model_post_init(self, context: Any) -> None:  # noqa: ARG002
         self._usage_tracker = UsageTracker(source_id=self.run_id)
         self._printer = Printer(
-            source_id=self.run_id, print_messages=self.print_messages
+            source_id=self.run_id,
+            print_messages=self.print_messages,
+            color_by=self.color_messages_by,
         )
 
     @property
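The RunContextWrapper → RunContext rename also moves the user-supplied payload from context to state and replaces interaction_history with a per-processor completions map keyed by ProcName. A migration sketch based on the fields above (the AppState model is hypothetical):

from pydantic import BaseModel

from grasp_agents.run_context import RunContext


class AppState(BaseModel):  # hypothetical per-run application state
    user_id: str


# 0.2.10: ctx = RunContextWrapper[AppState](context=AppState(user_id="u1"))
# 0.3.1:
ctx = RunContext[AppState](
    state=AppState(user_id="u1"),
    print_messages=True,
    color_messages_by="role",  # new ColoringMode field; "role" is the default
)

assert ctx.run_id  # 8-char id from the default_factory
assert ctx.completions["some_proc"] == []  # defaultdict(list), keyed by ProcName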
--- a/grasp_agents/typing/completion.py
+++ b/grasp_agents/typing/completion.py
@@ -1,14 +1,58 @@
-from abc import ABC
+import time
+from typing import Literal, TypeAlias
+from uuid import uuid4
 
-from pydantic import BaseModel
+from openai.types.chat.chat_completion import ChoiceLogprobs as CompletionChoiceLogprobs
+from pydantic import BaseModel, Field, NonNegativeFloat, NonNegativeInt
 
 from .message import AssistantMessage
 
+FinishReason: TypeAlias = Literal[
+    "stop", "length", "tool_calls", "content_filter", "function_call"
+]
+
+
+class Usage(BaseModel):
+    input_tokens: NonNegativeInt = 0
+    output_tokens: NonNegativeInt = 0
+    reasoning_tokens: NonNegativeInt | None = None
+    cached_tokens: NonNegativeInt | None = None
+    cost: NonNegativeFloat | None = None
+
+    def __add__(self, add_usage: "Usage") -> "Usage":
+        input_tokens = self.input_tokens + add_usage.input_tokens
+        output_tokens = self.output_tokens + add_usage.output_tokens
+        if self.reasoning_tokens is not None or add_usage.reasoning_tokens is not None:
+            reasoning_tokens = (self.reasoning_tokens or 0) + (
+                add_usage.reasoning_tokens or 0
+            )
+        else:
+            reasoning_tokens = None
+
+        if self.cached_tokens is not None or add_usage.cached_tokens is not None:
+            cached_tokens = (self.cached_tokens or 0) + (add_usage.cached_tokens or 0)
+        else:
+            cached_tokens = None
+
+        cost = (
+            (self.cost or 0.0) + add_usage.cost
+            if (add_usage.cost is not None)
+            else None
+        )
+        return Usage(
+            input_tokens=input_tokens,
+            output_tokens=output_tokens,
+            reasoning_tokens=reasoning_tokens,
+            cached_tokens=cached_tokens,
+            cost=cost,
+        )
+
 
 class CompletionChoice(BaseModel):
-    # TODO: add fields
     message: AssistantMessage
-    finish_reason: str | None
+    finish_reason: FinishReason | None
+    index: int
+    logprobs: CompletionChoiceLogprobs | None = None
 
 
 class CompletionError(BaseModel):
@@ -17,14 +61,16 @@ class CompletionError(BaseModel):
     code: int
 
 
-class Completion(BaseModel, ABC):
-    # TODO: add fields
+class Completion(BaseModel):
+    id: str = Field(default_factory=lambda: str(uuid4())[:8])
+    created: int = Field(default_factory=lambda: int(time.time()))
+    model: str
+    name: str | None = None
+    system_fingerprint: str | None = None
     choices: list[CompletionChoice]
-    model_id: str | None = None
+    usage: Usage | None = None
     error: CompletionError | None = None
 
-
-class CompletionChunk(BaseModel):
-    # TODO: add more fields and tool use support (and choices?)
-    delta: str | None = None
-    model_id: str | None = None
+    @property
+    def messages(self) -> list[AssistantMessage]:
+        return [choice.message for choice in self.choices if choice.message]
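The new Usage.__add__ makes aggregating per-completion usage a plain sum, treating a None token field as "not reported" (it stays None only when both operands lack it). Note the asymmetry for cost: the sum is None whenever the right-hand operand has no cost, even if the left one does. A small sketch of these semantics as implemented above:

from grasp_agents.typing.completion import Usage

a = Usage(input_tokens=100, output_tokens=20, reasoning_tokens=5, cost=0.002)
b = Usage(input_tokens=40, output_tokens=10)

total = a + b
assert total.input_tokens == 140 and total.output_tokens == 30
assert total.reasoning_tokens == 5  # b's None counts as 0 here
assert total.cached_tokens is None  # None on both sides stays None
assert total.cost is None           # right operand has no cost, so it is dropped
assert (b + a).cost == 0.002        # but the reverse order keeps it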