grasp_agents 0.2.11__py3-none-any.whl → 0.3.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51)
  1. grasp_agents/__init__.py +15 -14
  2. grasp_agents/cloud_llm.py +118 -131
  3. grasp_agents/comm_processor.py +201 -0
  4. grasp_agents/generics_utils.py +15 -7
  5. grasp_agents/llm.py +60 -31
  6. grasp_agents/llm_agent.py +229 -273
  7. grasp_agents/llm_agent_memory.py +58 -0
  8. grasp_agents/llm_policy_executor.py +482 -0
  9. grasp_agents/memory.py +20 -134
  10. grasp_agents/message_history.py +140 -0
  11. grasp_agents/openai/__init__.py +54 -36
  12. grasp_agents/openai/completion_chunk_converters.py +78 -0
  13. grasp_agents/openai/completion_converters.py +53 -30
  14. grasp_agents/openai/content_converters.py +13 -14
  15. grasp_agents/openai/converters.py +44 -68
  16. grasp_agents/openai/message_converters.py +58 -72
  17. grasp_agents/openai/openai_llm.py +101 -42
  18. grasp_agents/openai/tool_converters.py +24 -19
  19. grasp_agents/packet.py +24 -0
  20. grasp_agents/packet_pool.py +91 -0
  21. grasp_agents/printer.py +29 -15
  22. grasp_agents/processor.py +193 -0
  23. grasp_agents/prompt_builder.py +175 -192
  24. grasp_agents/run_context.py +20 -37
  25. grasp_agents/typing/completion.py +58 -12
  26. grasp_agents/typing/completion_chunk.py +173 -0
  27. grasp_agents/typing/converters.py +8 -12
  28. grasp_agents/typing/events.py +86 -0
  29. grasp_agents/typing/io.py +4 -13
  30. grasp_agents/typing/message.py +12 -50
  31. grasp_agents/typing/tool.py +52 -26
  32. grasp_agents/usage_tracker.py +6 -6
  33. grasp_agents/utils.py +3 -3
  34. grasp_agents/workflow/looped_workflow.py +132 -0
  35. grasp_agents/workflow/parallel_processor.py +95 -0
  36. grasp_agents/workflow/sequential_workflow.py +66 -0
  37. grasp_agents/workflow/workflow_processor.py +78 -0
  38. {grasp_agents-0.2.11.dist-info → grasp_agents-0.3.2.dist-info}/METADATA +41 -50
  39. grasp_agents-0.3.2.dist-info/RECORD +51 -0
  40. grasp_agents/agent_message.py +0 -27
  41. grasp_agents/agent_message_pool.py +0 -92
  42. grasp_agents/base_agent.py +0 -51
  43. grasp_agents/comm_agent.py +0 -217
  44. grasp_agents/llm_agent_state.py +0 -79
  45. grasp_agents/tool_orchestrator.py +0 -203
  46. grasp_agents/workflow/looped_agent.py +0 -134
  47. grasp_agents/workflow/sequential_agent.py +0 -72
  48. grasp_agents/workflow/workflow_agent.py +0 -88
  49. grasp_agents-0.2.11.dist-info/RECORD +0 -46
  50. {grasp_agents-0.2.11.dist-info → grasp_agents-0.3.2.dist-info}/WHEEL +0 -0
  51. {grasp_agents-0.2.11.dist-info → grasp_agents-0.3.2.dist-info}/licenses/LICENSE.md +0 -0
grasp_agents/typing/completion_chunk.py ADDED
@@ -0,0 +1,173 @@
1
+ import time
2
+ from collections import defaultdict
3
+ from collections.abc import Sequence
4
+ from uuid import uuid4
5
+
6
+ from openai.types.chat.chat_completion_chunk import (
7
+ ChoiceLogprobs as CompletionChunkChoiceLogprobs,
8
+ )
9
+ from openai.types.chat.chat_completion_token_logprob import (
10
+ ChatCompletionTokenLogprob as CompletionTokenLogprob,
11
+ )
12
+ from pydantic import BaseModel, Field
13
+
14
+ from .completion import (
15
+ Completion,
16
+ CompletionChoice,
17
+ CompletionChoiceLogprobs,
18
+ FinishReason,
19
+ Usage,
20
+ )
21
+ from .message import AssistantMessage, ToolCall
22
+
23
+
24
# One streamed fragment of a tool call. Fragments that share the same
# ``index`` belong to the same tool call. A single fragment may carry only
# part of the call (e.g. just an ``id``/``tool_name``, or just a piece of
# ``tool_arguments``), so every field except ``index`` is optional and
# defaults to None.
class CompletionChunkDeltaToolCall(BaseModel):
    id: str | None = None
    index: int
    tool_name: str | None = None
    tool_arguments: str | None = None
29
+
30
+
31
# Incremental update to one choice carried by a single chunk. A delta usually
# holds only some of these fields (a text fragment, a role, tool-call
# fragments, ...), so all of them are optional and default to None --
# consistently with ``content``, which already did.
class CompletionChunkChoiceDelta(BaseModel):
    content: str | None = None
    refusal: str | None = None
    role: str | None = None
    tool_calls: list[CompletionChunkDeltaToolCall] | None = None
36
+
37
+
38
# The part of a chunk that belongs to choice ``index``. ``finish_reason`` is
# None on every chunk of a choice except (typically) the one that ends it, so
# it defaults to None just like ``logprobs``.
class CompletionChunkChoice(BaseModel):
    delta: CompletionChunkChoiceDelta
    finish_reason: FinishReason | None = None
    index: int
    logprobs: CompletionChunkChoiceLogprobs | None = None
43
+
44
+
45
# A single streamed piece of a completion.
#   id       -- defaults to the first 8 hex digits of a random UUID4
#   created  -- defaults to the current Unix time, truncated to whole seconds
#   choices  -- per-choice deltas carried by this chunk
#   usage    -- token usage, when the provider reports it on this chunk
class CompletionChunk(BaseModel):
    id: str = Field(default_factory=lambda: uuid4().hex[:8])
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    name: str | None = Field(default=None)
    system_fingerprint: str | None = Field(default=None)
    choices: list[CompletionChunkChoice]
    usage: Usage | None = Field(default=None)
53
+
54
+
55
def _merge_tool_call_deltas(
    deltas: Sequence[CompletionChunkDeltaToolCall],
) -> list[ToolCall]:
    """Merge streamed tool-call fragments into complete tool calls.

    Fragments are grouped by their tool-call ``index``. For each group, the
    first non-None ``id`` and ``tool_name`` are kept and the
    ``tool_arguments`` fragments are concatenated in arrival order. Tool
    calls are returned in order of their first fragment.

    Raises:
        ValueError: If a merged tool call has no id, name or arguments.
    """
    merged: dict[int, dict[str, str | None]] = {}
    for delta in deltas:
        slot = merged.setdefault(
            delta.index, {"id": None, "tool_name": None, "tool_arguments": None}
        )
        # id/name are kept once seen so that providers which repeat them on
        # every fragment do not get them duplicated.
        if slot["id"] is None:
            slot["id"] = delta.id
        if slot["tool_name"] is None:
            slot["tool_name"] = delta.tool_name
        if delta.tool_arguments is not None:
            slot["tool_arguments"] = (slot["tool_arguments"] or "") + (
                delta.tool_arguments
            )

    tool_calls: list[ToolCall] = []
    for slot in merged.values():
        call_id = slot["id"]
        tool_name = slot["tool_name"]
        tool_arguments = slot["tool_arguments"]
        if call_id is None or tool_name is None or tool_arguments is None:
            raise ValueError(
                "Completion chunk tool calls must have id, tool_name, "
                "and tool_arguments set."
            )
        tool_calls.append(
            ToolCall(id=call_id, tool_name=tool_name, tool_arguments=tool_arguments)
        )
    return tool_calls


def combine_completion_chunks(chunks: list[CompletionChunk]) -> Completion:
    """Assemble a full ``Completion`` from a stream of ``CompletionChunk`` deltas.

    Chunks are merged per choice ``index``:

    * ``content`` and ``refusal`` fragments are concatenated in order (an empty
      content becomes the placeholder ``"<empty>"``, an empty refusal ``None``);
    * content/refusal logprobs are concatenated in order;
    * the last non-None ``finish_reason`` seen for the choice is kept;
    * tool-call fragments from *all* chunks are merged by tool-call ``index``.
      Streamed tool calls are generally spread over several chunks, and the
      chunk that carries ``finish_reason`` often has no tool calls at all, so
      looking only at the last chunk would drop or truncate them.

    Args:
        chunks: The chunks of a single streamed completion, in arrival order.

    Returns:
        The combined completion. ``created`` is the latest chunk timestamp and
        ``usage`` is the last non-None usage reported by any chunk.

    Raises:
        ValueError: If ``chunks`` is empty; if the chunks disagree on
            ``model``, ``name`` or ``system_fingerprint``; or if a merged tool
            call is missing its id, name or arguments.
    """
    if not chunks:
        raise ValueError("Cannot combine an empty list of completion chunks.")

    models = {chunk.model for chunk in chunks}
    if len(models) > 1:
        raise ValueError("All chunks must have the same model.")
    model = models.pop()

    names = {chunk.name for chunk in chunks}
    if len(names) > 1:
        raise ValueError("All chunks must have the same name.")
    name = names.pop()

    system_fingerprints = {chunk.system_fingerprint for chunk in chunks}
    if len(system_fingerprints) > 1:
        raise ValueError("All chunks must have the same system fingerprint.")
    system_fingerprint = system_fingerprints.pop()

    created = max(chunk.created for chunk in chunks)

    # Usage is normally reported once, on the final chunk (and only if it was
    # requested); take the last reported value instead of assuming chunks[-1].
    usage = next(
        (chunk.usage for chunk in reversed(chunks) if chunk.usage is not None),
        None,
    )

    contents_per_choice: defaultdict[int, str] = defaultdict(str)
    refusals_per_choice: defaultdict[int, str] = defaultdict(str)
    logp_contents_per_choice: defaultdict[int, list[CompletionTokenLogprob]] = (
        defaultdict(list)
    )
    logp_refusals_per_choice: defaultdict[int, list[CompletionTokenLogprob]] = (
        defaultdict(list)
    )
    # Also records which choices exist, in order of first appearance.
    finish_reasons_per_choice: dict[int, FinishReason | None] = {}
    tool_call_deltas_per_choice: defaultdict[
        int, list[CompletionChunkDeltaToolCall]
    ] = defaultdict(list)

    for chunk in chunks:
        for choice in chunk.choices:
            index = choice.index
            delta = choice.delta

            contents_per_choice[index] += delta.content or ""
            refusals_per_choice[index] += delta.refusal or ""

            if choice.logprobs is not None:
                logp_contents_per_choice[index].extend(choice.logprobs.content or [])
                logp_refusals_per_choice[index].extend(choice.logprobs.refusal or [])

            # Register the choice on first sight; afterwards only a chunk that
            # actually carries a finish reason may replace it, so a trailing
            # chunk with finish_reason=None cannot erase it.
            if (
                choice.finish_reason is not None
                or index not in finish_reasons_per_choice
            ):
                finish_reasons_per_choice[index] = choice.finish_reason

            if delta.tool_calls:
                tool_call_deltas_per_choice[index].extend(delta.tool_calls)

    choices: list[CompletionChoice] = []
    for index, finish_reason in finish_reasons_per_choice.items():
        tool_calls = _merge_tool_call_deltas(tool_call_deltas_per_choice[index])

        message = AssistantMessage(
            name=name,
            content=contents_per_choice[index] or "<empty>",
            refusal=refusals_per_choice[index] or None,
            tool_calls=tool_calls or None,
        )

        logprobs: CompletionChoiceLogprobs | None = None
        if logp_contents_per_choice[index] or logp_refusals_per_choice[index]:
            logprobs = CompletionChoiceLogprobs(
                content=logp_contents_per_choice[index],
                refusal=logp_refusals_per_choice[index],
            )

        choices.append(
            CompletionChoice(
                index=index,
                message=message,
                finish_reason=finish_reason,
                logprobs=logprobs,
            )
        )

    return Completion(
        model=model,
        name=name,
        created=created,
        system_fingerprint=system_fingerprint,
        choices=choices,
        usage=usage,
    )
grasp_agents/typing/converters.py CHANGED
@@ -1,10 +1,10 @@
1
1
  from abc import ABC, abstractmethod
2
- from collections.abc import AsyncIterator
3
2
  from typing import Any
4
3
 
5
4
  from pydantic import BaseModel
6
5
 
7
- from .completion import Completion, CompletionChunk
6
+ from .completion import Completion, Usage
7
+ from .completion_chunk import CompletionChunk
8
8
  from .content import Content
9
9
  from .message import AssistantMessage, Message, SystemMessage, ToolMessage, UserMessage
10
10
  from .tool import BaseTool, ToolChoice
@@ -38,9 +38,12 @@ class Converters(ABC):
38
38
 
39
39
  @staticmethod
40
40
  @abstractmethod
41
- def from_assistant_message(
42
- raw_message: Any, raw_usage: Any, **kwargs: Any
43
- ) -> AssistantMessage:
41
+ def from_completion_usage(raw_usage: Any, **kwargs: Any) -> Usage:
42
+ pass
43
+
44
+ @staticmethod
45
+ @abstractmethod
46
+ def from_assistant_message(raw_message: Any, **kwargs: Any) -> AssistantMessage:
44
47
  pass
45
48
 
46
49
  @staticmethod
@@ -103,10 +106,3 @@ class Converters(ABC):
103
106
  @abstractmethod
104
107
  def from_completion_chunk(raw_chunk: Any, **kwargs: Any) -> CompletionChunk:
105
108
  pass
106
-
107
- @staticmethod
108
- @abstractmethod
109
- def from_completion_chunk_iterator(
110
- raw_chunk_iterator: AsyncIterator[Any], **kwargs: Any
111
- ) -> AsyncIterator[CompletionChunk]:
112
- pass
grasp_agents/typing/events.py ADDED
@@ -0,0 +1,86 @@
1
+ import time
2
+ from enum import StrEnum
3
+ from typing import Any, Generic, Literal, TypeVar
4
+
5
+ from pydantic import BaseModel, Field
6
+
7
+ from ..packet import Packet
8
+ from .completion import Completion
9
+ from .completion_chunk import CompletionChunk
10
+ from .message import AssistantMessage, SystemMessage, ToolCall, ToolMessage, UserMessage
11
+
12
+
13
# Who an event is attributed to. Being a StrEnum, each member compares equal
# to (and serializes as) its lowercase string value.
class EventSourceType(StrEnum):
    LLM = "llm"
    AGENT = "agent"
    USER = "user"
    TOOL = "tool"
    PROCESSOR = "processor"


# What kind of payload an event carries. The Event subclasses in this module
# each pin ``type`` to exactly one of these members.
class EventType(StrEnum):
    SYS_MSG = "system_message"  # payload: SystemMessage
    USR_MSG = "user_message"  # payload: UserMessage
    TOOL_MSG = "tool_message"  # payload: ToolMessage
    TOOL_CALL = "tool_call"  # payload: ToolCall
    GEN_MSG = "gen_message"  # payload: AssistantMessage
    COMP = "completion"  # payload: Completion
    COMP_CHUNK = "completion_chunk"  # payload: CompletionChunk
    PACKET = "packet"  # payload: Packet
    PROC_OUT = "processor_output"  # payload: unconstrained (Any)
31
+
32
+
33
_T = TypeVar("_T")


def _unix_time_seconds() -> int:
    # Current Unix time, truncated to whole seconds.
    return int(time.time())


# Immutable envelope shared by every event:
#   type    -- what kind of event this is
#   source  -- who the event is attributed to
#   created -- Unix timestamp in whole seconds, taken at construction time
#   name    -- optional name attached to the event
#   data    -- the payload, typed by the generic parameter
class Event(BaseModel, Generic[_T], frozen=True):
    type: EventType
    source: EventSourceType
    created: int = Field(default_factory=_unix_time_seconds)
    name: str | None = Field(default=None)
    data: _T
42
+
43
+
44
# Concrete event types. Each subclass fixes ``type`` and ``source`` to a single
# Literal value that is also its default, so constructing one only requires
# ``data`` (``name`` and ``created`` keep the defaults inherited from Event).


# A complete Completion, attributed to the LLM.
class CompletionEvent(Event[Completion], frozen=True):
    type: Literal[EventType.COMP] = EventType.COMP
    source: Literal[EventSourceType.LLM] = EventSourceType.LLM


# A single streamed CompletionChunk, attributed to the LLM.
class CompletionChunkEvent(Event[CompletionChunk], frozen=True):
    type: Literal[EventType.COMP_CHUNK] = EventType.COMP_CHUNK
    source: Literal[EventSourceType.LLM] = EventSourceType.LLM


# A generated AssistantMessage, attributed to the LLM.
class GenMessageEvent(Event[AssistantMessage], frozen=True):
    type: Literal[EventType.GEN_MSG] = EventType.GEN_MSG
    source: Literal[EventSourceType.LLM] = EventSourceType.LLM


# A ToolCall, attributed to the agent.
class ToolCallEvent(Event[ToolCall], frozen=True):
    type: Literal[EventType.TOOL_CALL] = EventType.TOOL_CALL
    source: Literal[EventSourceType.AGENT] = EventSourceType.AGENT


# A ToolMessage (a tool's reply), attributed to the tool.
class ToolMessageEvent(Event[ToolMessage], frozen=True):
    type: Literal[EventType.TOOL_MSG] = EventType.TOOL_MSG
    source: Literal[EventSourceType.TOOL] = EventSourceType.TOOL


# A UserMessage, attributed to the user.
class UserMessageEvent(Event[UserMessage], frozen=True):
    type: Literal[EventType.USR_MSG] = EventType.USR_MSG
    source: Literal[EventSourceType.USER] = EventSourceType.USER


# A SystemMessage, attributed to the agent.
class SystemMessageEvent(Event[SystemMessage], frozen=True):
    type: Literal[EventType.SYS_MSG] = EventType.SYS_MSG
    source: Literal[EventSourceType.AGENT] = EventSourceType.AGENT


# A Packet (with any payload type), attributed to a processor.
class PacketEvent(Event[Packet[Any]], frozen=True):
    type: Literal[EventType.PACKET] = EventType.PACKET
    source: Literal[EventSourceType.PROCESSOR] = EventSourceType.PROCESSOR


# A processor's output; the payload type is unconstrained (Any).
class ProcOutputEvent(Event[Any], frozen=True):
    type: Literal[EventType.PROC_OUT] = EventType.PROC_OUT
    source: Literal[EventSourceType.PROCESSOR] = EventSourceType.PROCESSOR
grasp_agents/typing/io.py CHANGED
@@ -1,25 +1,16 @@
1
- from collections.abc import Mapping
2
1
  from typing import TypeAlias, TypeVar
3
2
 
4
3
  from pydantic import BaseModel
5
4
 
6
- from .content import ImageData
7
-
8
- AgentID: TypeAlias = str
9
-
10
-
11
- class AgentState(BaseModel):
12
- pass
5
+ ProcName: TypeAlias = str
13
6
 
14
7
 
15
8
  class LLMPromptArgs(BaseModel):
16
9
  pass
17
10
 
18
11
 
19
- InT = TypeVar("InT", contravariant=True) # noqa: PLC0105
20
- OutT = TypeVar("OutT", covariant=True) # noqa: PLC0105
21
- StateT = TypeVar("StateT", bound=AgentState, covariant=True) # noqa: PLC0105
12
+ InT_contra = TypeVar("InT_contra", contravariant=True)
13
+ OutT_co = TypeVar("OutT_co", covariant=True)
14
+ MemT_co = TypeVar("MemT_co", covariant=True)
22
15
 
23
16
  LLMPrompt: TypeAlias = str
24
- LLMFormattedSystemArgs: TypeAlias = Mapping[str, str | int | bool]
25
- LLMFormattedArgs: TypeAlias = Mapping[str, str | int | bool | ImageData]
grasp_agents/typing/message.py CHANGED
@@ -4,7 +4,7 @@ from enum import StrEnum
4
4
  from typing import Annotated, Any, Literal, TypeAlias
5
5
  from uuid import uuid4
6
6
 
7
- from pydantic import BaseModel, Field, NonNegativeFloat, NonNegativeInt
7
+ from pydantic import BaseModel, Field
8
8
  from pydantic.json import pydantic_encoder
9
9
 
10
10
  from .content import Content, ImageData
@@ -18,85 +18,48 @@ class Role(StrEnum):
18
18
  TOOL = "tool"
19
19
 
20
20
 
21
- class Usage(BaseModel):
22
- input_tokens: NonNegativeInt = 0
23
- output_tokens: NonNegativeInt = 0
24
- reasoning_tokens: NonNegativeInt | None = None
25
- cached_tokens: NonNegativeInt | None = None
26
- cost: NonNegativeFloat | None = None
27
-
28
- def __add__(self, add_usage: "Usage") -> "Usage":
29
- input_tokens = self.input_tokens + add_usage.input_tokens
30
- output_tokens = self.output_tokens + add_usage.output_tokens
31
- if self.reasoning_tokens is not None or add_usage.reasoning_tokens is not None:
32
- reasoning_tokens = (self.reasoning_tokens or 0) + (
33
- add_usage.reasoning_tokens or 0
34
- )
35
- else:
36
- reasoning_tokens = None
37
-
38
- if self.cached_tokens is not None or add_usage.cached_tokens is not None:
39
- cached_tokens = (self.cached_tokens or 0) + (add_usage.cached_tokens or 0)
40
- else:
41
- cached_tokens = None
42
-
43
- cost = (
44
- (self.cost or 0.0) + add_usage.cost
45
- if (add_usage.cost is not None)
46
- else None
47
- )
48
- return Usage(
49
- input_tokens=input_tokens,
50
- output_tokens=output_tokens,
51
- reasoning_tokens=reasoning_tokens,
52
- cached_tokens=cached_tokens,
53
- cost=cost,
54
- )
55
-
56
-
57
21
  class MessageBase(BaseModel):
58
- message_id: Hashable = Field(default_factory=lambda: str(uuid4())[:8])
59
- model_id: str | None = None
22
+ id: Hashable = Field(default_factory=lambda: str(uuid4())[:8])
23
+ name: str | None = None
60
24
 
61
25
 
62
26
  class AssistantMessage(MessageBase):
63
27
  role: Literal[Role.ASSISTANT] = Role.ASSISTANT
64
28
  content: str | None
65
- usage: Usage | None = None
66
29
  tool_calls: Sequence[ToolCall] | None = None
67
30
  refusal: str | None = None
68
31
 
69
32
 
70
33
  class UserMessage(MessageBase):
71
34
  role: Literal[Role.USER] = Role.USER
72
- content: Content
35
+ content: Content | str
73
36
 
74
37
  @classmethod
75
- def from_text(cls, text: str, model_id: str | None = None) -> "UserMessage":
76
- return cls(content=Content.from_text(text), model_id=model_id)
38
+ def from_text(cls, text: str, name: str | None = None) -> "UserMessage":
39
+ return cls(content=Content.from_text(text), name=name)
77
40
 
78
41
  @classmethod
79
42
  def from_formatted_prompt(
80
43
  cls,
81
44
  prompt_template: str,
82
45
  prompt_args: Mapping[str, str | int | bool | ImageData] | None = None,
83
- model_id: str | None = None,
46
+ name: str | None = None,
84
47
  ) -> "UserMessage":
85
48
  content = Content.from_formatted_prompt(
86
49
  prompt_template=prompt_template, prompt_args=prompt_args
87
50
  )
88
51
 
89
- return cls(content=content, model_id=model_id)
52
+ return cls(content=content, name=name)
90
53
 
91
54
  @classmethod
92
55
  def from_content_parts(
93
56
  cls,
94
57
  content_parts: Sequence[str | ImageData],
95
- model_id: str | None = None,
58
+ name: str | None = None,
96
59
  ) -> "UserMessage":
97
60
  content = Content.from_content_parts(content_parts)
98
61
 
99
- return cls(content=content, model_id=model_id)
62
+ return cls(content=content, name=name)
100
63
 
101
64
 
102
65
  class SystemMessage(MessageBase):
@@ -114,13 +77,12 @@ class ToolMessage(MessageBase):
114
77
  cls,
115
78
  tool_output: Any,
116
79
  tool_call: ToolCall,
117
- model_id: str | None = None,
118
80
  indent: int = 2,
119
81
  ) -> "ToolMessage":
120
82
  return cls(
121
83
  content=json.dumps(tool_output, default=pydantic_encoder, indent=indent),
122
84
  tool_call_id=tool_call.id,
123
- model_id=model_id,
85
+ name=tool_call.tool_name,
124
86
  )
125
87
 
126
88
 
@@ -129,4 +91,4 @@ Message = Annotated[
129
91
  Field(discriminator="role"),
130
92
  ]
131
93
 
132
- Conversation: TypeAlias = list[Message]
94
+ Messages: TypeAlias = list[Message]
grasp_agents/typing/tool.py CHANGED
@@ -1,23 +1,31 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  from abc import ABC, abstractmethod
4
- from typing import TYPE_CHECKING, Any, ClassVar, Generic, Literal, TypeAlias, TypeVar
4
+ from typing import (
5
+ TYPE_CHECKING,
6
+ Any,
7
+ ClassVar,
8
+ Generic,
9
+ Literal,
10
+ TypeAlias,
11
+ TypeVar,
12
+ )
5
13
 
6
14
  from pydantic import BaseModel, PrivateAttr, TypeAdapter
7
15
 
8
16
  from ..generics_utils import AutoInstanceAttributesMixin
9
17
 
10
18
  if TYPE_CHECKING:
11
- from ..run_context import CtxT, RunContextWrapper
19
+ from ..run_context import CtxT, RunContext
12
20
  else:
13
21
  CtxT = TypeVar("CtxT")
14
22
 
15
- class RunContextWrapper(Generic[CtxT]):
16
- """Runtime placeholder so RunContextWrapper[CtxT] works"""
23
+ class RunContext(Generic[CtxT]):
24
+ """Runtime placeholder so RunContext[CtxT] works"""
17
25
 
18
26
 
19
- _ToolInT = TypeVar("_ToolInT", bound=BaseModel, contravariant=True) # noqa: PLC0105
20
- _ToolOutT = TypeVar("_ToolOutT", covariant=True) # noqa: PLC0105
27
+ _InT_contra = TypeVar("_InT_contra", bound=BaseModel, contravariant=True)
28
+ _OutT_co = TypeVar("_OutT_co", covariant=True)
21
29
 
22
30
 
23
31
  class ToolCall(BaseModel):
@@ -27,45 +35,63 @@ class ToolCall(BaseModel):
27
35
 
28
36
 
29
37
  class BaseTool(
30
- AutoInstanceAttributesMixin, BaseModel, ABC, Generic[_ToolInT, _ToolOutT, CtxT]
38
+ AutoInstanceAttributesMixin,
39
+ BaseModel,
40
+ ABC,
41
+ Generic[_InT_contra, _OutT_co, CtxT],
31
42
  ):
32
43
  _generic_arg_to_instance_attr_map: ClassVar[dict[int, str]] = {
33
- 0: "_in_schema",
34
- 1: "_out_schema",
44
+ 0: "_in_type",
45
+ 1: "_out_type",
35
46
  }
36
47
 
37
48
  name: str
38
49
  description: str
39
50
 
40
- _in_schema: type[_ToolInT] = PrivateAttr()
41
- _out_schema: type[_ToolOutT] = PrivateAttr()
51
+ _in_type: type[_InT_contra] = PrivateAttr()
52
+ _out_type: type[_OutT_co] = PrivateAttr()
53
+
54
+ # _in_type_adapter: TypeAdapter[_InT_contra] = PrivateAttr()
55
+ # _out_type_adapter: TypeAdapter[_OutT_co] = PrivateAttr()
42
56
 
43
- # Supported by OpenAI API
44
- strict: bool | None = None
57
+ # def model_post_init(self, context: Any) -> None:
58
+ # self._in_type_adapter = TypeAdapter(self._in_type)
59
+ # self._out_type_adapter = TypeAdapter(self._out_type)
45
60
 
46
61
  @property
47
- def in_schema(self) -> type[_ToolInT]: # type: ignore[reportInvalidTypeVarUse]
62
+ def in_type(self) -> type[_InT_contra]: # type: ignore[reportInvalidTypeVarUse]
48
63
  # Exposing the type of a contravariant variable only, should be type safe
49
- return self._in_schema
64
+ return self._in_type
50
65
 
51
66
  @property
52
- def out_schema(self) -> type[_ToolOutT]:
53
- return self._out_schema
67
+ def out_type(self) -> type[_OutT_co]:
68
+ return self._out_type
69
+
70
+ # @property
71
+ # def in_type_adapter(self) -> TypeAdapter[_InT_contra]:
72
+ # return self._in_type_adapter
73
+
74
+ # @property
75
+ # def out_type_adapter(self) -> TypeAdapter[_OutT_co]:
76
+ # return self._out_type_adapter
54
77
 
55
78
  @abstractmethod
56
79
  async def run(
57
- self, inp: _ToolInT, ctx: RunContextWrapper[CtxT] | None = None
58
- ) -> _ToolOutT:
80
+ self, inp: _InT_contra, ctx: RunContext[CtxT] | None = None
81
+ ) -> _OutT_co:
59
82
  pass
60
83
 
61
84
  async def __call__(
62
- self, ctx: RunContextWrapper[CtxT] | None = None, **kwargs: Any
63
- ) -> _ToolOutT:
64
- result = await self.run(self._in_schema(**kwargs), ctx=ctx)
85
+ self, ctx: RunContext[CtxT] | None = None, **kwargs: Any
86
+ ) -> _OutT_co:
87
+ input_args = TypeAdapter(self._in_type).validate_python(kwargs)
88
+ output = await self.run(input_args, ctx=ctx)
65
89
 
66
- return TypeAdapter(self._out_schema).validate_python(result)
90
+ return TypeAdapter(self._out_type).validate_python(output)
67
91
 
68
92
 
69
- ToolChoice: TypeAlias = (
70
- Literal["none", "auto", "required"] | BaseTool[BaseModel, Any, Any]
71
- )
93
+ class NamedToolChoice(BaseModel):
94
+ name: str
95
+
96
+
97
+ ToolChoice: TypeAlias = Literal["none", "auto", "required"] | NamedToolChoice
grasp_agents/usage_tracker.py CHANGED
@@ -7,7 +7,7 @@ import yaml
7
7
  from pydantic import BaseModel, Field
8
8
  from termcolor import colored
9
9
 
10
- from .typing.message import AssistantMessage, Message, Usage
10
+ from .typing.completion import Completion, Usage
11
11
 
12
12
  logger = logging.getLogger(__name__)
13
13
 
@@ -58,20 +58,20 @@ class UsageTracker(BaseModel):
58
58
  usage.cost = (input_cost + output_cost + reasoning_cost + cached_cost) / 1e6
59
59
 
60
60
  def update(
61
- self, messages: Sequence[Message], model_name: str | None = None
61
+ self, completions: Sequence[Completion], model_name: str | None = None
62
62
  ) -> None:
63
63
  if model_name is not None and self.costs_dict is not None:
64
64
  model_costs_dict = self.costs_dict.get(model_name.split(":", 1)[-1])
65
65
  else:
66
66
  model_costs_dict = None
67
67
 
68
- for message in messages:
69
- if isinstance(message, AssistantMessage) and message.usage is not None:
68
+ for completion in completions:
69
+ if completion.usage is not None:
70
70
  if model_costs_dict is not None:
71
71
  self._add_cost_to_usage(
72
- usage=message.usage, model_costs_dict=model_costs_dict
72
+ usage=completion.usage, model_costs_dict=model_costs_dict
73
73
  )
74
- self.total_usage += message.usage
74
+ self.total_usage += completion.usage
75
75
 
76
76
  def reset(self) -> None:
77
77
  self.total_usage = Usage()
grasp_agents/utils.py CHANGED
@@ -3,7 +3,7 @@ import asyncio
3
3
  import json
4
4
  import re
5
5
  from collections.abc import Coroutine, Mapping
6
- from datetime import datetime
6
+ from datetime import UTC, datetime
7
7
  from logging import getLogger
8
8
  from pathlib import Path
9
9
  from typing import Any, TypeVar, overload
@@ -126,7 +126,7 @@ def read_contents_from_file(
126
126
  return Path(file_path).read_bytes()
127
127
  return Path(file_path).read_text()
128
128
  except FileNotFoundError:
129
- logger.error(f"File {file_path} not found.")
129
+ logger.exception(f"File {file_path} not found.")
130
130
  return ""
131
131
 
132
132
 
@@ -157,4 +157,4 @@ async def asyncio_gather_with_pbar(
157
157
 
158
158
 
159
159
  def get_timestamp() -> str:
160
- return datetime.now().strftime("%Y%m%d_%H%M%S")
160
+ return datetime.now(UTC).strftime("%Y%m%d_%H%M%S")