grasp_agents 0.2.11__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51)
  1. grasp_agents/__init__.py +15 -14
  2. grasp_agents/cloud_llm.py +118 -131
  3. grasp_agents/comm_processor.py +201 -0
  4. grasp_agents/generics_utils.py +15 -7
  5. grasp_agents/llm.py +60 -31
  6. grasp_agents/llm_agent.py +229 -273
  7. grasp_agents/llm_agent_memory.py +58 -0
  8. grasp_agents/llm_policy_executor.py +482 -0
  9. grasp_agents/memory.py +20 -134
  10. grasp_agents/message_history.py +140 -0
  11. grasp_agents/openai/__init__.py +54 -36
  12. grasp_agents/openai/completion_chunk_converters.py +78 -0
  13. grasp_agents/openai/completion_converters.py +53 -30
  14. grasp_agents/openai/content_converters.py +13 -14
  15. grasp_agents/openai/converters.py +44 -68
  16. grasp_agents/openai/message_converters.py +58 -72
  17. grasp_agents/openai/openai_llm.py +101 -42
  18. grasp_agents/openai/tool_converters.py +24 -19
  19. grasp_agents/packet.py +24 -0
  20. grasp_agents/packet_pool.py +91 -0
  21. grasp_agents/printer.py +29 -15
  22. grasp_agents/processor.py +194 -0
  23. grasp_agents/prompt_builder.py +175 -192
  24. grasp_agents/run_context.py +20 -37
  25. grasp_agents/typing/completion.py +58 -12
  26. grasp_agents/typing/completion_chunk.py +173 -0
  27. grasp_agents/typing/converters.py +8 -12
  28. grasp_agents/typing/events.py +86 -0
  29. grasp_agents/typing/io.py +4 -13
  30. grasp_agents/typing/message.py +12 -50
  31. grasp_agents/typing/tool.py +52 -26
  32. grasp_agents/usage_tracker.py +6 -6
  33. grasp_agents/utils.py +3 -3
  34. grasp_agents/workflow/looped_workflow.py +132 -0
  35. grasp_agents/workflow/parallel_processor.py +95 -0
  36. grasp_agents/workflow/sequential_workflow.py +66 -0
  37. grasp_agents/workflow/workflow_processor.py +78 -0
  38. {grasp_agents-0.2.11.dist-info → grasp_agents-0.3.1.dist-info}/METADATA +41 -50
  39. grasp_agents-0.3.1.dist-info/RECORD +51 -0
  40. grasp_agents/agent_message.py +0 -27
  41. grasp_agents/agent_message_pool.py +0 -92
  42. grasp_agents/base_agent.py +0 -51
  43. grasp_agents/comm_agent.py +0 -217
  44. grasp_agents/llm_agent_state.py +0 -79
  45. grasp_agents/tool_orchestrator.py +0 -203
  46. grasp_agents/workflow/looped_agent.py +0 -134
  47. grasp_agents/workflow/sequential_agent.py +0 -72
  48. grasp_agents/workflow/workflow_agent.py +0 -88
  49. grasp_agents-0.2.11.dist-info/RECORD +0 -46
  50. {grasp_agents-0.2.11.dist-info → grasp_agents-0.3.1.dist-info}/WHEEL +0 -0
  51. {grasp_agents-0.2.11.dist-info → grasp_agents-0.3.1.dist-info}/licenses/LICENSE.md +0 -0
@@ -1,251 +1,234 @@
1
- from collections.abc import Sequence
2
- from copy import deepcopy
3
- from typing import ClassVar, Generic, Protocol, cast
1
+ import json
2
+ from collections.abc import Mapping, Sequence
3
+ from typing import ClassVar, Generic, Protocol, TypeAlias
4
4
 
5
5
  from pydantic import BaseModel, TypeAdapter
6
6
 
7
7
  from .generics_utils import AutoInstanceAttributesMixin
8
- from .run_context import CtxT, RunContextWrapper
9
- from .typing.content import ImageData
10
- from .typing.io import (
11
- InT,
12
- LLMFormattedArgs,
13
- LLMFormattedSystemArgs,
14
- LLMPrompt,
15
- LLMPromptArgs,
16
- )
8
+ from .run_context import CtxT, RunContext
9
+ from .typing.content import Content, ImageData
10
+ from .typing.io import InT_contra, LLMPrompt, LLMPromptArgs
17
11
  from .typing.message import UserMessage
18
12
 
19
13
 
20
- class DummySchema(BaseModel):
21
- pass
22
-
23
-
24
- class FormatSystemArgsHandler(Protocol[CtxT]):
14
+ class MakeSystemPromptHandler(Protocol[CtxT]):
25
15
  def __call__(
26
16
  self,
27
- sys_args: LLMPromptArgs,
17
+ sys_args: LLMPromptArgs | None,
28
18
  *,
29
- ctx: RunContextWrapper[CtxT] | None,
30
- ) -> LLMFormattedSystemArgs: ...
19
+ ctx: RunContext[CtxT] | None,
20
+ ) -> str: ...
31
21
 
32
22
 
33
- class FormatInputArgsHandler(Protocol[InT, CtxT]):
23
+ class MakeInputContentHandler(Protocol[InT_contra, CtxT]):
34
24
  def __call__(
35
25
  self,
36
26
  *,
37
- usr_args: LLMPromptArgs,
38
- in_args: InT,
27
+ in_args: InT_contra | None,
28
+ usr_args: LLMPromptArgs | None,
39
29
  batch_idx: int,
40
- ctx: RunContextWrapper[CtxT] | None,
41
- ) -> LLMFormattedArgs: ...
30
+ ctx: RunContext[CtxT] | None,
31
+ ) -> Content: ...
32
+
42
33
 
34
+ PromptArgumentType: TypeAlias = str | bool | int | ImageData
43
35
 
44
- class PromptBuilder(AutoInstanceAttributesMixin, Generic[InT, CtxT]):
36
+
37
+ class PromptBuilder(AutoInstanceAttributesMixin, Generic[InT_contra, CtxT]):
45
38
  _generic_arg_to_instance_attr_map: ClassVar[dict[int, str]] = {0: "_in_type"}
46
39
 
47
40
  def __init__(
48
41
  self,
49
- agent_id: str,
50
- sys_prompt: LLMPrompt | None,
51
- in_prompt: LLMPrompt | None,
52
- sys_args_schema: type[LLMPromptArgs],
53
- usr_args_schema: type[LLMPromptArgs],
42
+ agent_name: str,
43
+ sys_prompt_template: LLMPrompt | None,
44
+ in_prompt_template: LLMPrompt | None,
45
+ sys_args_schema: type[LLMPromptArgs] | None = None,
46
+ usr_args_schema: type[LLMPromptArgs] | None = None,
54
47
  ):
55
- self._in_type: type[InT]
48
+ self._in_type: type[InT_contra]
56
49
  super().__init__()
57
50
 
58
- self._agent_id = agent_id
59
- self.sys_prompt = sys_prompt
60
- self.in_prompt = in_prompt
51
+ self._agent_name = agent_name
52
+ self.sys_prompt_template = sys_prompt_template
53
+ self.in_prompt_template = in_prompt_template
61
54
  self.sys_args_schema = sys_args_schema
62
55
  self.usr_args_schema = usr_args_schema
63
- self.format_sys_args_impl: FormatSystemArgsHandler[CtxT] | None = None
64
- self.format_in_args_impl: FormatInputArgsHandler[InT, CtxT] | None = None
65
-
66
- self._in_args_type_adapter: TypeAdapter[InT] = TypeAdapter(self._in_type)
67
-
68
- def _format_sys_args(
69
- self,
70
- sys_args: LLMPromptArgs,
71
- ctx: RunContextWrapper[CtxT] | None = None,
72
- ) -> LLMFormattedSystemArgs:
73
- if self.format_sys_args_impl:
74
- return self.format_sys_args_impl(sys_args=sys_args, ctx=ctx)
75
-
76
- return sys_args.model_dump(exclude_unset=True)
77
-
78
- def _format_in_args(
79
- self,
80
- *,
81
- usr_args: LLMPromptArgs,
82
- in_args: InT,
83
- batch_idx: int = 0,
84
- ctx: RunContextWrapper[CtxT] | None = None,
85
- ) -> LLMFormattedArgs:
86
- if self.format_in_args_impl:
87
- return self.format_in_args_impl(
88
- usr_args=usr_args, in_args=in_args, batch_idx=batch_idx, ctx=ctx
89
- )
90
-
91
- if not isinstance(in_args, BaseModel) and in_args is not None:
92
- raise TypeError(
93
- "Cannot apply default formatting to non-BaseModel received arguments."
94
- )
95
-
96
- usr_args_ = usr_args
97
- in_args_ = DummySchema() if in_args is None else in_args
98
-
99
- usr_args_dump = usr_args_.model_dump(exclude_unset=True)
100
- in_args_dump = in_args_.model_dump(exclude={"selected_recipient_ids"})
56
+ self.make_sys_prompt_impl: MakeSystemPromptHandler[CtxT] | None = None
57
+ self.make_in_content_impl: MakeInputContentHandler[InT_contra, CtxT] | None = (
58
+ None
59
+ )
101
60
 
102
- return usr_args_dump | in_args_dump
61
+ self._in_args_type_adapter: TypeAdapter[InT_contra] = TypeAdapter(self._in_type)
103
62
 
104
63
  def make_sys_prompt(
105
- self,
106
- sys_args: LLMPromptArgs,
107
- *,
108
- ctx: RunContextWrapper[CtxT] | None,
109
- ) -> LLMPrompt | None:
110
- if self.sys_prompt is None:
64
+ self, sys_args: LLMPromptArgs | None = None, ctx: RunContext[CtxT] | None = None
65
+ ) -> str | None:
66
+ if self.sys_prompt_template is None:
111
67
  return None
112
- val_sys_args = self.sys_args_schema.model_validate(sys_args)
113
- fmt_sys_args = self._format_sys_args(val_sys_args, ctx=ctx)
114
68
 
115
- return self.sys_prompt.format(**fmt_sys_args)
69
+ val_sys_args = sys_args
70
+ if sys_args is not None:
71
+ if self.sys_args_schema is not None:
72
+ val_sys_args = self.sys_args_schema.model_validate(sys_args)
73
+ else:
74
+ raise TypeError(
75
+ "System prompt template is set, but system arguments schema is not "
76
+ "provided."
77
+ )
116
78
 
117
- def _usr_messages_from_text(self, text: str) -> list[UserMessage]:
118
- return [UserMessage.from_text(text, model_id=self._agent_id)]
79
+ if self.make_sys_prompt_impl:
80
+ return self.make_sys_prompt_impl(sys_args=val_sys_args, ctx=ctx)
119
81
 
120
- def _usr_messages_from_content_parts(
121
- self, content_parts: Sequence[str | ImageData]
122
- ) -> Sequence[UserMessage]:
123
- return [UserMessage.from_content_parts(content_parts, model_id=self._agent_id)]
82
+ sys_args_dict = (
83
+ val_sys_args.model_dump(exclude_unset=True) if val_sys_args else {}
84
+ )
124
85
 
125
- def _usr_messages_from_in_args(
126
- self, in_args_batch: Sequence[InT]
127
- ) -> Sequence[UserMessage]:
128
- return [
129
- UserMessage.from_text(
130
- self._in_args_type_adapter.dump_json(
131
- inp,
132
- exclude_unset=True,
133
- indent=2,
134
- exclude={"selected_recipient_ids"},
135
- warnings="error",
136
- ).decode("utf-8"),
137
- model_id=self._agent_id,
138
- )
139
- for inp in in_args_batch
140
- ]
86
+ return self.sys_prompt_template.format(**sys_args_dict)
141
87
 
142
- def _usr_messages_from_prompt_template(
88
+ def make_in_content(
143
89
  self,
144
- in_prompt: LLMPrompt,
145
- in_args_batch: Sequence[InT] | None = None,
146
- usr_args_batch: Sequence[LLMPromptArgs] | None = None,
147
- ctx: RunContextWrapper[CtxT] | None = None,
148
- ) -> Sequence[UserMessage]:
149
- usr_args_batch_, in_args_batch_ = self._align_in_usr_batches(
150
- in_args_batch, usr_args_batch
90
+ *,
91
+ in_args: InT_contra | None,
92
+ usr_args: LLMPromptArgs | None,
93
+ batch_idx: int = 0,
94
+ ctx: RunContext[CtxT] | None = None,
95
+ ) -> Content:
96
+ val_in_args, val_usr_args = self._validate_prompt_args(
97
+ in_args=in_args, usr_args=usr_args
151
98
  )
152
99
 
153
- val_usr_args_batch_ = [
154
- self.usr_args_schema.model_validate(u) for u in usr_args_batch_
155
- ]
156
- val_in_args_batch_ = [
157
- self._in_args_type_adapter.validate_python(inp) for inp in in_args_batch_
158
- ]
159
-
160
- formatted_in_args_batch = [
161
- self._format_in_args(
162
- usr_args=val_usr_args, in_args=val_in_args, batch_idx=i, ctx=ctx
100
+ if self.make_in_content_impl:
101
+ return self.make_in_content_impl(
102
+ in_args=val_in_args, usr_args=val_usr_args, batch_idx=batch_idx, ctx=ctx
163
103
  )
164
- for i, (val_usr_args, val_in_args) in enumerate(
165
- zip(val_usr_args_batch_, val_in_args_batch_, strict=False)
166
- )
167
- ]
168
104
 
169
- return [
170
- UserMessage.from_formatted_prompt(
171
- prompt_template=in_prompt, prompt_args=in_args
105
+ combined_args = self._combine_args(in_args=val_in_args, usr_args=val_usr_args)
106
+ if isinstance(combined_args, str):
107
+ return Content.from_text(combined_args)
108
+
109
+ if self.in_prompt_template is not None:
110
+ return Content.from_formatted_prompt(
111
+ self.in_prompt_template, prompt_args=combined_args
172
112
  )
173
- for in_args in formatted_in_args_batch
174
- ]
113
+
114
+ return Content.from_text(json.dumps(combined_args, indent=2))
175
115
 
176
116
  def make_user_messages(
177
117
  self,
178
118
  chat_inputs: LLMPrompt | Sequence[str | ImageData] | None = None,
179
- in_args: InT | Sequence[InT] | None = None,
180
- usr_args: LLMPromptArgs | Sequence[LLMPromptArgs] | None = None,
181
- entry_point: bool = False,
182
- ctx: RunContextWrapper[CtxT] | None = None,
119
+ in_args_batch: Sequence[InT_contra] | None = None,
120
+ usr_args: LLMPromptArgs | None = None,
121
+ ctx: RunContext[CtxT] | None = None,
183
122
  ) -> Sequence[UserMessage]:
184
- # 1) Direct user input (e.g. chat input)
185
- if chat_inputs is not None or entry_point:
186
- """
187
- * If chat inputs are provided, use them instead of the predefined
188
- input prompt template
189
- * In a multi-agent system, the predefined input prompt is used to
190
- construct agent inputs using the combination of received
191
- and user arguments.
192
- However, the first agent run (entry point) has no received
193
- messages, so we use the chat inputs directly, if provided.
194
- """
123
+ if chat_inputs:
195
124
  if isinstance(chat_inputs, LLMPrompt):
196
125
  return self._usr_messages_from_text(chat_inputs)
126
+ return self._usr_messages_from_content_parts(chat_inputs)
197
127
 
198
- if isinstance(chat_inputs, Sequence) and chat_inputs:
199
- return self._usr_messages_from_content_parts(chat_inputs)
128
+ in_content_batch = [
129
+ self.make_in_content(
130
+ in_args=in_args, usr_args=usr_args, batch_idx=i, ctx=ctx
131
+ )
132
+ for i, in_args in enumerate(in_args_batch or [None])
133
+ ]
134
+ return [
135
+ UserMessage(content=in_content, name=self._agent_name)
136
+ for in_content in in_content_batch
137
+ ]
200
138
 
201
- in_args_batch = cast(
202
- "Sequence[InT] | None",
203
- in_args if (isinstance(in_args, Sequence) or not in_args) else [in_args],
204
- )
139
+ def _usr_messages_from_text(self, text: str) -> list[UserMessage]:
140
+ return [UserMessage.from_text(text, name=self._agent_name)]
205
141
 
206
- # 2) No input prompt template + received args → raw JSON messages
207
- if self.in_prompt is None and in_args_batch:
208
- return self._usr_messages_from_in_args(in_args_batch)
209
-
210
- usr_args_batch = cast(
211
- "Sequence[LLMPromptArgs] | None",
212
- (
213
- usr_args
214
- if (isinstance(usr_args, Sequence) or not usr_args)
215
- else [usr_args]
216
- ),
217
- )
142
+ def _usr_messages_from_content_parts(
143
+ self, content_parts: Sequence[str | ImageData]
144
+ ) -> list[UserMessage]:
145
+ return [UserMessage.from_content_parts(content_parts, name=self._agent_name)]
218
146
 
219
- # 3) Input prompt template + any args → batch & format
220
- if self.in_prompt is not None:
221
- if in_args_batch and not isinstance(in_args_batch[0], BaseModel):
147
+ def _validate_prompt_args(
148
+ self,
149
+ *,
150
+ in_args: InT_contra | None,
151
+ usr_args: LLMPromptArgs | None,
152
+ ) -> tuple[InT_contra | None, LLMPromptArgs | None]:
153
+ val_usr_args = usr_args
154
+ if usr_args is not None:
155
+ if self.in_prompt_template is None:
156
+ raise TypeError(
157
+ "Input prompt template is not set, but user arguments are provided."
158
+ )
159
+ if self.usr_args_schema is None:
160
+ raise TypeError(
161
+ "User arguments schema is not provided, but user arguments are "
162
+ "given."
163
+ )
164
+ val_usr_args = self.usr_args_schema.model_validate(usr_args)
165
+
166
+ val_in_args = in_args
167
+ if in_args is not None:
168
+ val_in_args = self._in_args_type_adapter.validate_python(in_args)
169
+ if isinstance(val_in_args, BaseModel):
170
+ _, has_image = self._format_pydantic_prompt_args(val_in_args)
171
+ if has_image and self.in_prompt_template is None:
172
+ raise TypeError(
173
+ "BaseModel input arguments contain ImageData, but input prompt "
174
+ "template is not set. Cannot format input arguments."
175
+ )
176
+ elif self.in_prompt_template is not None:
222
177
  raise TypeError(
223
178
  "Cannot use the input prompt template with "
224
- "non-BaseModel received arguments."
179
+ "non-BaseModel input arguments."
225
180
  )
226
- return self._usr_messages_from_prompt_template(
227
- in_prompt=self.in_prompt,
228
- usr_args_batch=usr_args_batch,
229
- in_args_batch=in_args_batch,
230
- ctx=ctx,
231
- )
232
181
 
233
- return []
182
+ return val_in_args, val_usr_args
183
+
184
+ @staticmethod
185
+ def _format_pydantic_prompt_args(
186
+ inp: BaseModel,
187
+ ) -> tuple[dict[str, PromptArgumentType], bool]:
188
+ formatted_args: dict[str, PromptArgumentType] = {}
189
+ contains_image_data = False
190
+ for field in type(inp).model_fields:
191
+ if field == "selected_recipients":
192
+ continue
193
+
194
+ val = getattr(inp, field)
195
+ if isinstance(val, (int, str, bool)):
196
+ formatted_args[field] = val
197
+ elif isinstance(val, ImageData):
198
+ formatted_args[field] = val
199
+ contains_image_data = True
200
+ elif isinstance(val, BaseModel):
201
+ formatted_args[field] = val.model_dump_json(indent=2, warnings="error")
202
+ else:
203
+ raise TypeError(
204
+ f"Field '{field}' in prompt arguments must be of type "
205
+ "int, str, bool, BaseModel, or ImageData."
206
+ )
207
+
208
+ return formatted_args, contains_image_data
234
209
 
235
- def _align_in_usr_batches(
210
+ def _combine_args(
236
211
  self,
237
- in_args_batch: Sequence[InT] | None = None,
238
- usr_args_batch: Sequence[LLMPromptArgs] | None = None,
239
- ) -> tuple[Sequence[LLMPromptArgs | DummySchema], Sequence[InT | DummySchema]]:
240
- usr_args_batch_ = usr_args_batch or [DummySchema()]
241
- in_args_batch_ = in_args_batch or [DummySchema()]
242
-
243
- # Broadcast singleton → match lengths
244
- if len(usr_args_batch_) == 1 and len(in_args_batch_) > 1:
245
- usr_args_batch_ = [deepcopy(usr_args_batch_[0]) for _ in in_args_batch_]
246
- if len(in_args_batch_) == 1 and len(usr_args_batch_) > 1:
247
- in_args_batch_ = [deepcopy(in_args_batch_[0]) for _ in usr_args_batch_]
248
- if len(usr_args_batch_) != len(in_args_batch_):
249
- raise ValueError("User args and received args must have the same length")
250
-
251
- return usr_args_batch_, in_args_batch_
212
+ *,
213
+ in_args: InT_contra | None,
214
+ usr_args: LLMPromptArgs | None,
215
+ ) -> Mapping[str, PromptArgumentType] | str:
216
+ fmt_usr_args, _ = (
217
+ self._format_pydantic_prompt_args(usr_args) if usr_args else ({}, False)
218
+ )
219
+
220
+ if in_args is None:
221
+ return fmt_usr_args
222
+
223
+ if isinstance(in_args, BaseModel):
224
+ fmt_in_args, _ = self._format_pydantic_prompt_args(in_args)
225
+ return fmt_in_args | fmt_usr_args
226
+
227
+ combined_args_str = self._in_args_type_adapter.dump_json(
228
+ in_args, indent=2, warnings="error"
229
+ ).decode("utf-8")
230
+ if usr_args is not None:
231
+ fmt_usr_args_str = usr_args.model_dump_json(indent=2, warnings="error")
232
+ combined_args_str += "\n" + fmt_usr_args_str
233
+
234
+ return combined_args_str
@@ -1,58 +1,39 @@
1
- from collections.abc import Sequence
2
- from typing import Any, Generic, TypeAlias, TypeVar
1
+ from collections import defaultdict
2
+ from collections.abc import Mapping
3
+ from typing import Any, Generic, TypeVar
3
4
  from uuid import uuid4
4
5
 
5
6
  from pydantic import BaseModel, ConfigDict, Field, PrivateAttr
6
7
 
7
- from .printer import Printer
8
- from .typing.content import ImageData
9
- from .typing.io import (
10
- AgentID,
11
- AgentState,
12
- InT,
13
- LLMPrompt,
14
- LLMPromptArgs,
15
- OutT,
16
- StateT,
17
- )
8
+ from grasp_agents.typing.completion import Completion
9
+
10
+ from .printer import ColoringMode, Printer
11
+ from .typing.io import LLMPromptArgs, ProcName
18
12
  from .usage_tracker import UsageTracker
19
13
 
20
14
 
21
15
  class RunArgs(BaseModel):
22
16
  sys: LLMPromptArgs = Field(default_factory=LLMPromptArgs)
23
- usr: LLMPromptArgs | Sequence[LLMPromptArgs] = Field(default_factory=LLMPromptArgs)
17
+ usr: LLMPromptArgs = Field(default_factory=LLMPromptArgs)
24
18
 
25
19
  model_config = ConfigDict(extra="forbid")
26
20
 
27
21
 
28
- class InteractionRecord(BaseModel, Generic[InT, OutT, StateT]):
29
- source_id: str
30
- recipient_ids: Sequence[AgentID]
31
- state: StateT
32
- chat_inputs: LLMPrompt | Sequence[str | ImageData] | None = None
33
- sys_prompt: LLMPrompt | None = None
34
- in_prompt: LLMPrompt | None = None
35
- sys_args: LLMPromptArgs | None = None
36
- usr_args: LLMPromptArgs | Sequence[LLMPromptArgs] | None = None
37
- in_args: InT | Sequence[InT] | None = None
38
- outputs: Sequence[OutT]
39
-
40
- model_config = ConfigDict(extra="forbid", frozen=True)
41
-
42
-
43
- InteractionHistory: TypeAlias = list[InteractionRecord[Any, Any, AgentState]]
44
-
45
-
46
22
  CtxT = TypeVar("CtxT")
47
23
 
48
24
 
49
- class RunContextWrapper(BaseModel, Generic[CtxT]):
50
- context: CtxT | None = None
25
+ class RunContext(BaseModel, Generic[CtxT]):
51
26
  run_id: str = Field(default_factory=lambda: str(uuid4())[:8], frozen=True)
52
- run_args: dict[AgentID, RunArgs] = Field(default_factory=dict)
53
- interaction_history: InteractionHistory = Field(default_factory=list) # type: ignore[valid-type]
27
+
28
+ state: CtxT | None = None
29
+
30
+ run_args: dict[ProcName, RunArgs] = Field(default_factory=dict)
31
+ completions: Mapping[ProcName, list[Completion]] = Field(
32
+ default_factory=lambda: defaultdict(list)
33
+ )
54
34
 
55
35
  print_messages: bool = False
36
+ color_messages_by: ColoringMode = "role"
56
37
 
57
38
  _usage_tracker: UsageTracker = PrivateAttr()
58
39
  _printer: Printer = PrivateAttr()
@@ -60,7 +41,9 @@ class RunContextWrapper(BaseModel, Generic[CtxT]):
60
41
  def model_post_init(self, context: Any) -> None: # noqa: ARG002
61
42
  self._usage_tracker = UsageTracker(source_id=self.run_id)
62
43
  self._printer = Printer(
63
- source_id=self.run_id, print_messages=self.print_messages
44
+ source_id=self.run_id,
45
+ print_messages=self.print_messages,
46
+ color_by=self.color_messages_by,
64
47
  )
65
48
 
66
49
  @property
@@ -1,14 +1,58 @@
1
- from abc import ABC
1
+ import time
2
+ from typing import Literal, TypeAlias
3
+ from uuid import uuid4
2
4
 
3
- from pydantic import BaseModel
5
+ from openai.types.chat.chat_completion import ChoiceLogprobs as CompletionChoiceLogprobs
6
+ from pydantic import BaseModel, Field, NonNegativeFloat, NonNegativeInt
4
7
 
5
8
  from .message import AssistantMessage
6
9
 
10
+ FinishReason: TypeAlias = Literal[
11
+ "stop", "length", "tool_calls", "content_filter", "function_call"
12
+ ]
13
+
14
+
15
+ class Usage(BaseModel):
16
+ input_tokens: NonNegativeInt = 0
17
+ output_tokens: NonNegativeInt = 0
18
+ reasoning_tokens: NonNegativeInt | None = None
19
+ cached_tokens: NonNegativeInt | None = None
20
+ cost: NonNegativeFloat | None = None
21
+
22
+ def __add__(self, add_usage: "Usage") -> "Usage":
23
+ input_tokens = self.input_tokens + add_usage.input_tokens
24
+ output_tokens = self.output_tokens + add_usage.output_tokens
25
+ if self.reasoning_tokens is not None or add_usage.reasoning_tokens is not None:
26
+ reasoning_tokens = (self.reasoning_tokens or 0) + (
27
+ add_usage.reasoning_tokens or 0
28
+ )
29
+ else:
30
+ reasoning_tokens = None
31
+
32
+ if self.cached_tokens is not None or add_usage.cached_tokens is not None:
33
+ cached_tokens = (self.cached_tokens or 0) + (add_usage.cached_tokens or 0)
34
+ else:
35
+ cached_tokens = None
36
+
37
+ cost = (
38
+ (self.cost or 0.0) + add_usage.cost
39
+ if (add_usage.cost is not None)
40
+ else None
41
+ )
42
+ return Usage(
43
+ input_tokens=input_tokens,
44
+ output_tokens=output_tokens,
45
+ reasoning_tokens=reasoning_tokens,
46
+ cached_tokens=cached_tokens,
47
+ cost=cost,
48
+ )
49
+
7
50
 
8
51
  class CompletionChoice(BaseModel):
9
- # TODO: add fields
10
52
  message: AssistantMessage
11
- finish_reason: str | None
53
+ finish_reason: FinishReason | None
54
+ index: int
55
+ logprobs: CompletionChoiceLogprobs | None = None
12
56
 
13
57
 
14
58
  class CompletionError(BaseModel):
@@ -17,14 +61,16 @@ class CompletionError(BaseModel):
17
61
  code: int
18
62
 
19
63
 
20
- class Completion(BaseModel, ABC):
21
- # TODO: add fields
64
+ class Completion(BaseModel):
65
+ id: str = Field(default_factory=lambda: str(uuid4())[:8])
66
+ created: int = Field(default_factory=lambda: int(time.time()))
67
+ model: str
68
+ name: str | None = None
69
+ system_fingerprint: str | None = None
22
70
  choices: list[CompletionChoice]
23
- model_id: str | None = None
71
+ usage: Usage | None = None
24
72
  error: CompletionError | None = None
25
73
 
26
-
27
- class CompletionChunk(BaseModel):
28
- # TODO: add more fields and tool use support (and choices?)
29
- delta: str | None = None
30
- model_id: str | None = None
74
+ @property
75
+ def messages(self) -> list[AssistantMessage]:
76
+ return [choice.message for choice in self.choices if choice.message]