grasp_agents 0.2.11__py3-none-any.whl → 0.3.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- grasp_agents/__init__.py +15 -14
- grasp_agents/cloud_llm.py +118 -131
- grasp_agents/comm_processor.py +201 -0
- grasp_agents/generics_utils.py +15 -7
- grasp_agents/llm.py +60 -31
- grasp_agents/llm_agent.py +229 -273
- grasp_agents/llm_agent_memory.py +58 -0
- grasp_agents/llm_policy_executor.py +482 -0
- grasp_agents/memory.py +20 -134
- grasp_agents/message_history.py +140 -0
- grasp_agents/openai/__init__.py +54 -36
- grasp_agents/openai/completion_chunk_converters.py +78 -0
- grasp_agents/openai/completion_converters.py +53 -30
- grasp_agents/openai/content_converters.py +13 -14
- grasp_agents/openai/converters.py +44 -68
- grasp_agents/openai/message_converters.py +58 -72
- grasp_agents/openai/openai_llm.py +101 -42
- grasp_agents/openai/tool_converters.py +24 -19
- grasp_agents/packet.py +24 -0
- grasp_agents/packet_pool.py +91 -0
- grasp_agents/printer.py +29 -15
- grasp_agents/processor.py +193 -0
- grasp_agents/prompt_builder.py +175 -192
- grasp_agents/run_context.py +20 -37
- grasp_agents/typing/completion.py +58 -12
- grasp_agents/typing/completion_chunk.py +173 -0
- grasp_agents/typing/converters.py +8 -12
- grasp_agents/typing/events.py +86 -0
- grasp_agents/typing/io.py +4 -13
- grasp_agents/typing/message.py +12 -50
- grasp_agents/typing/tool.py +52 -26
- grasp_agents/usage_tracker.py +6 -6
- grasp_agents/utils.py +3 -3
- grasp_agents/workflow/looped_workflow.py +132 -0
- grasp_agents/workflow/parallel_processor.py +95 -0
- grasp_agents/workflow/sequential_workflow.py +66 -0
- grasp_agents/workflow/workflow_processor.py +78 -0
- {grasp_agents-0.2.11.dist-info → grasp_agents-0.3.2.dist-info}/METADATA +41 -50
- grasp_agents-0.3.2.dist-info/RECORD +51 -0
- grasp_agents/agent_message.py +0 -27
- grasp_agents/agent_message_pool.py +0 -92
- grasp_agents/base_agent.py +0 -51
- grasp_agents/comm_agent.py +0 -217
- grasp_agents/llm_agent_state.py +0 -79
- grasp_agents/tool_orchestrator.py +0 -203
- grasp_agents/workflow/looped_agent.py +0 -134
- grasp_agents/workflow/sequential_agent.py +0 -72
- grasp_agents/workflow/workflow_agent.py +0 -88
- grasp_agents-0.2.11.dist-info/RECORD +0 -46
- {grasp_agents-0.2.11.dist-info → grasp_agents-0.3.2.dist-info}/WHEEL +0 -0
- {grasp_agents-0.2.11.dist-info → grasp_agents-0.3.2.dist-info}/licenses/LICENSE.md +0 -0
@@ -0,0 +1,173 @@
|
|
1
|
+
import time
|
2
|
+
from collections import defaultdict
|
3
|
+
from collections.abc import Sequence
|
4
|
+
from uuid import uuid4
|
5
|
+
|
6
|
+
from openai.types.chat.chat_completion_chunk import (
|
7
|
+
ChoiceLogprobs as CompletionChunkChoiceLogprobs,
|
8
|
+
)
|
9
|
+
from openai.types.chat.chat_completion_token_logprob import (
|
10
|
+
ChatCompletionTokenLogprob as CompletionTokenLogprob,
|
11
|
+
)
|
12
|
+
from pydantic import BaseModel, Field
|
13
|
+
|
14
|
+
from .completion import (
|
15
|
+
Completion,
|
16
|
+
CompletionChoice,
|
17
|
+
CompletionChoiceLogprobs,
|
18
|
+
FinishReason,
|
19
|
+
Usage,
|
20
|
+
)
|
21
|
+
from .message import AssistantMessage, ToolCall
|
22
|
+
|
23
|
+
|
24
|
+
class CompletionChunkDeltaToolCall(BaseModel):
    """
    A fragment of a tool call carried in a streamed completion chunk delta.

    Any of ``id``, ``tool_name`` and ``tool_arguments`` may be absent (None) in
    a given fragment; fragments are grouped by ``index`` when chunks are
    combined.
    """

    id: str | None = None
    # Tool call index used to group fragments of the same call (assumed to
    # mirror the provider's streamed delta index — confirm in the converter).
    index: int
    tool_name: str | None = None
    # Partial (or complete) arguments string; fragments are concatenated.
    tool_arguments: str | None = None
|
29
|
+
|
30
|
+
|
31
|
+
class CompletionChunkChoiceDelta(BaseModel):
    """
    Incremental update to one choice of a streamed completion.

    Every field is optional: a single chunk typically carries only some of
    them (e.g. a content fragment, or a tool call fragment).
    """

    content: str | None = None
    refusal: str | None = None
    role: str | None = None
    tool_calls: list[CompletionChunkDeltaToolCall] | None = None
|
36
|
+
|
37
|
+
|
38
|
+
class CompletionChunkChoice(BaseModel):
    """
    One choice inside a streamed completion chunk.

    ``finish_reason`` is None until the chunk that ends this choice.
    """

    delta: CompletionChunkChoiceDelta
    finish_reason: FinishReason | None = None
    # Identifies which choice of the completion this delta belongs to.
    index: int
    logprobs: CompletionChunkChoiceLogprobs | None = None
|
43
|
+
|
44
|
+
|
45
|
+
class CompletionChunk(BaseModel):
    """
    One streamed piece of a completion.

    A sequence of chunks can be merged into a full ``Completion`` with
    ``combine_completion_chunks``.
    """

    # Short random identifier: the first 8 characters of a UUID4 string.
    id: str = Field(default_factory=lambda: str(uuid4())[:8])
    # Creation time as integer Unix seconds (defaults to the current time).
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    name: str | None = None
    system_fingerprint: str | None = None
    choices: list[CompletionChunkChoice]
    # Token usage, if reported for this chunk (typically only near the end of
    # the stream, when usage reporting was requested).
    usage: Usage | None = None
|
53
|
+
|
54
|
+
|
55
|
+
def _merge_tool_call_delta(
    calls: list[CompletionChunkDeltaToolCall],
    open_calls: dict[int, CompletionChunkDeltaToolCall],
    delta: CompletionChunkDeltaToolCall,
) -> None:
    """
    Fold one streamed tool call fragment into the tool calls of one choice.

    ``calls`` collects merged tool calls in the order they were started.
    ``open_calls`` maps a tool call ``index`` to the entry in ``calls`` that
    later fragments with the same index are merged into. A fragment starts a
    new entry when its index has not been seen yet, or when it carries an id
    that differs from the id already recorded for that index.
    """
    current = open_calls.get(delta.index)
    if current is None or (
        delta.id is not None and current.id is not None and delta.id != current.id
    ):
        # Copy so that merging never mutates the caller's chunk objects.
        current = delta.model_copy()
        calls.append(current)
        open_calls[delta.index] = current
        return

    # The id and the tool name normally arrive once, in the first fragment;
    # keep the first value seen. Argument fragments are concatenated in order.
    if current.id is None:
        current.id = delta.id
    if current.tool_name is None:
        current.tool_name = delta.tool_name
    if delta.tool_arguments is not None:
        current.tool_arguments = (
            current.tool_arguments or ""
        ) + delta.tool_arguments


def combine_completion_chunks(chunks: list[CompletionChunk]) -> Completion:
    """
    Merge a stream of completion chunks into a single Completion.

    Chunks are folded per choice, keyed by ``CompletionChunkChoice.index``:

    - content and refusal deltas are concatenated;
    - logprob entries are appended in order;
    - the last non-None finish reason is kept;
    - tool call fragments are merged per tool call index. The first non-None
      ``id``/``tool_name`` is kept and ``tool_arguments`` fragments are
      concatenated.

    Args:
        chunks: Completion chunks in the order they were received.

    Returns:
        A Completion whose choices are ordered by choice index. ``created`` is
        the latest chunk timestamp. ``usage`` is taken from the last chunk that
        carries one.

    Raises:
        ValueError: In any of these cases:
            - ``chunks`` is empty;
            - the chunks disagree on model, name or system fingerprint;
            - a merged tool call is missing its id, tool name or arguments.
    """
    if not chunks:
        raise ValueError("Cannot combine an empty list of completion chunks.")

    model_list = {chunk.model for chunk in chunks}
    if len(model_list) > 1:
        raise ValueError("All chunks must have the same model.")
    model = model_list.pop()

    name_list = {chunk.name for chunk in chunks}
    if len(name_list) > 1:
        raise ValueError("All chunks must have the same name.")
    name = name_list.pop()

    system_fingerprints_list = {chunk.system_fingerprint for chunk in chunks}
    if len(system_fingerprints_list) > 1:
        raise ValueError("All chunks must have the same system fingerprint.")
    system_fingerprint = system_fingerprints_list.pop()

    created = max(chunk.created for chunk in chunks)

    # Usage is reported (if requested) at the end of the stream, normally in
    # the last chunk; take the last chunk that actually carries it.
    usage = next(
        (chunk.usage for chunk in reversed(chunks) if chunk.usage is not None), None
    )

    contents_per_choice: defaultdict[int, str] = defaultdict(str)
    refusals_per_choice: defaultdict[int, str] = defaultdict(str)
    logp_contents_per_choice: defaultdict[int, list[CompletionTokenLogprob]] = (
        defaultdict(list)
    )
    logp_refusals_per_choice: defaultdict[int, list[CompletionTokenLogprob]] = (
        defaultdict(list)
    )
    # Also serves as the registry of every choice index seen in the stream.
    finish_reasons_per_choice: dict[int, FinishReason | None] = {}
    tool_calls_per_choice: defaultdict[int, list[CompletionChunkDeltaToolCall]] = (
        defaultdict(list)
    )
    open_tool_calls_per_choice: defaultdict[
        int, dict[int, CompletionChunkDeltaToolCall]
    ] = defaultdict(dict)

    for chunk in chunks:
        for choice in chunk.choices:
            index = choice.index
            delta = choice.delta

            # Concatenate content and refusal tokens for each choice
            contents_per_choice[index] += delta.content or ""
            refusals_per_choice[index] += delta.refusal or ""

            # Concatenate logprobs for content and refusal tokens for each choice
            if choice.logprobs is not None:
                logp_contents_per_choice[index].extend(choice.logprobs.content or [])
                logp_refusals_per_choice[index].extend(choice.logprobs.refusal or [])

            # Keep the last finish reason that was actually reported, so that a
            # trailing chunk repeating the choice with None cannot erase it.
            finish_reasons_per_choice.setdefault(index, None)
            if choice.finish_reason is not None:
                finish_reasons_per_choice[index] = choice.finish_reason

            # Tool calls are streamed as fragments spread over many chunks
            # (the finishing chunk usually has no tool call delta at all), so
            # they are accumulated rather than taken from the last chunk.
            for tool_call_delta in delta.tool_calls or []:
                _merge_tool_call_delta(
                    tool_calls_per_choice[index],
                    open_tool_calls_per_choice[index],
                    tool_call_delta,
                )

    choices: list[CompletionChoice] = []
    for index in sorted(finish_reasons_per_choice):
        tool_calls: list[ToolCall] = []
        for merged_call in tool_calls_per_choice[index]:
            if (
                merged_call.id is None
                or merged_call.tool_name is None
                or merged_call.tool_arguments is None
            ):
                raise ValueError(
                    "Completion chunk tool calls must have id, tool_name, "
                    "and tool_arguments set."
                )
            tool_calls.append(
                ToolCall(
                    id=merged_call.id,
                    tool_name=merged_call.tool_name,
                    tool_arguments=merged_call.tool_arguments,
                )
            )

        message = AssistantMessage(
            name=name,
            # Empty content is replaced by the literal "<empty>" placeholder,
            # also for tool-call-only messages.
            content=contents_per_choice[index] or "<empty>",
            refusal=refusals_per_choice[index] or None,
            tool_calls=tool_calls or None,
        )

        logprobs: CompletionChoiceLogprobs | None = None
        if logp_contents_per_choice[index] or logp_refusals_per_choice[index]:
            logprobs = CompletionChoiceLogprobs(
                content=logp_contents_per_choice[index],
                refusal=logp_refusals_per_choice[index],
            )

        choices.append(
            CompletionChoice(
                index=index,
                message=message,
                finish_reason=finish_reasons_per_choice[index],
                logprobs=logprobs,
            )
        )

    return Completion(
        model=model,
        name=name,
        created=created,
        system_fingerprint=system_fingerprint,
        choices=choices,
        usage=usage,
    )
|
@@ -1,10 +1,10 @@
|
|
1
1
|
from abc import ABC, abstractmethod
|
2
|
-
from collections.abc import AsyncIterator
|
3
2
|
from typing import Any
|
4
3
|
|
5
4
|
from pydantic import BaseModel
|
6
5
|
|
7
|
-
from .completion import Completion,
|
6
|
+
from .completion import Completion, Usage
|
7
|
+
from .completion_chunk import CompletionChunk
|
8
8
|
from .content import Content
|
9
9
|
from .message import AssistantMessage, Message, SystemMessage, ToolMessage, UserMessage
|
10
10
|
from .tool import BaseTool, ToolChoice
|
@@ -38,9 +38,12 @@ class Converters(ABC):
|
|
38
38
|
|
39
39
|
@staticmethod
|
40
40
|
@abstractmethod
|
41
|
-
def
|
42
|
-
|
43
|
-
|
41
|
+
def from_completion_usage(raw_usage: Any, **kwargs: Any) -> Usage:
|
42
|
+
pass
|
43
|
+
|
44
|
+
@staticmethod
|
45
|
+
@abstractmethod
|
46
|
+
def from_assistant_message(raw_message: Any, **kwargs: Any) -> AssistantMessage:
|
44
47
|
pass
|
45
48
|
|
46
49
|
@staticmethod
|
@@ -103,10 +106,3 @@ class Converters(ABC):
|
|
103
106
|
@abstractmethod
|
104
107
|
def from_completion_chunk(raw_chunk: Any, **kwargs: Any) -> CompletionChunk:
|
105
108
|
pass
|
106
|
-
|
107
|
-
@staticmethod
|
108
|
-
@abstractmethod
|
109
|
-
def from_completion_chunk_iterator(
|
110
|
-
raw_chunk_iterator: AsyncIterator[Any], **kwargs: Any
|
111
|
-
) -> AsyncIterator[CompletionChunk]:
|
112
|
-
pass
|
@@ -0,0 +1,86 @@
|
|
1
|
+
import time
|
2
|
+
from enum import StrEnum
|
3
|
+
from typing import Any, Generic, Literal, TypeVar
|
4
|
+
|
5
|
+
from pydantic import BaseModel, Field
|
6
|
+
|
7
|
+
from ..packet import Packet
|
8
|
+
from .completion import Completion
|
9
|
+
from .completion_chunk import CompletionChunk
|
10
|
+
from .message import AssistantMessage, SystemMessage, ToolCall, ToolMessage, UserMessage
|
11
|
+
|
12
|
+
|
13
|
+
class EventSourceType(StrEnum):
    """Kind of component that an Event is attributed to (``Event.source``)."""

    LLM = "llm"
    AGENT = "agent"
    USER = "user"
    TOOL = "tool"
    PROCESSOR = "processor"
|
19
|
+
|
20
|
+
|
21
|
+
class EventType(StrEnum):
    """
    Discriminator for the kind of payload an Event carries (``Event.type``).

    Each concrete Event subclass pins ``type`` to one of these members.
    """

    SYS_MSG = "system_message"
    USR_MSG = "user_message"
    TOOL_MSG = "tool_message"
    TOOL_CALL = "tool_call"
    GEN_MSG = "gen_message"
    COMP = "completion"
    COMP_CHUNK = "completion_chunk"
    PACKET = "packet"
    PROC_OUT = "processor_output"
|
31
|
+
|
32
|
+
|
33
|
+
# Type of the payload stored in ``Event.data``.
_T = TypeVar("_T")


class Event(BaseModel, Generic[_T], frozen=True):
    """
    Immutable, generic event envelope.

    Concrete subclasses parametrize ``_T`` with the payload type and pin
    ``type`` and ``source`` to fixed literal values.
    """

    type: EventType
    source: EventSourceType
    # Creation time as integer Unix seconds (defaults to the current time).
    created: int = Field(default_factory=lambda: int(time.time()))
    name: str | None = None
    data: _T
|
42
|
+
|
43
|
+
|
44
|
+
class CompletionEvent(Event[Completion], frozen=True):
    """Event whose payload is a full Completion; source is always LLM."""

    type: Literal[EventType.COMP] = EventType.COMP
    source: Literal[EventSourceType.LLM] = EventSourceType.LLM
|
47
|
+
|
48
|
+
|
49
|
+
class CompletionChunkEvent(Event[CompletionChunk], frozen=True):
    """Event whose payload is one streamed CompletionChunk; source is LLM."""

    type: Literal[EventType.COMP_CHUNK] = EventType.COMP_CHUNK
    source: Literal[EventSourceType.LLM] = EventSourceType.LLM
|
52
|
+
|
53
|
+
|
54
|
+
class GenMessageEvent(Event[AssistantMessage], frozen=True):
    """Event whose payload is a generated AssistantMessage; source is LLM."""

    type: Literal[EventType.GEN_MSG] = EventType.GEN_MSG
    source: Literal[EventSourceType.LLM] = EventSourceType.LLM
|
57
|
+
|
58
|
+
|
59
|
+
class ToolCallEvent(Event[ToolCall], frozen=True):
    """Event whose payload is a ToolCall; source is always AGENT."""

    type: Literal[EventType.TOOL_CALL] = EventType.TOOL_CALL
    source: Literal[EventSourceType.AGENT] = EventSourceType.AGENT
|
62
|
+
|
63
|
+
|
64
|
+
class ToolMessageEvent(Event[ToolMessage], frozen=True):
    """Event whose payload is a ToolMessage; source is always TOOL."""

    type: Literal[EventType.TOOL_MSG] = EventType.TOOL_MSG
    source: Literal[EventSourceType.TOOL] = EventSourceType.TOOL
|
67
|
+
|
68
|
+
|
69
|
+
class UserMessageEvent(Event[UserMessage], frozen=True):
    """Event whose payload is a UserMessage; source is always USER."""

    type: Literal[EventType.USR_MSG] = EventType.USR_MSG
    source: Literal[EventSourceType.USER] = EventSourceType.USER
|
72
|
+
|
73
|
+
|
74
|
+
class SystemMessageEvent(Event[SystemMessage], frozen=True):
    """Event whose payload is a SystemMessage; source is always AGENT."""

    type: Literal[EventType.SYS_MSG] = EventType.SYS_MSG
    source: Literal[EventSourceType.AGENT] = EventSourceType.AGENT
|
77
|
+
|
78
|
+
|
79
|
+
class PacketEvent(Event[Packet[Any]], frozen=True):
    """Event whose payload is a Packet (of any item type); source is PROCESSOR."""

    type: Literal[EventType.PACKET] = EventType.PACKET
    source: Literal[EventSourceType.PROCESSOR] = EventSourceType.PROCESSOR
|
82
|
+
|
83
|
+
|
84
|
+
class ProcOutputEvent(Event[Any], frozen=True):
    """Event carrying an arbitrary processor output payload; source is PROCESSOR."""

    type: Literal[EventType.PROC_OUT] = EventType.PROC_OUT
    source: Literal[EventSourceType.PROCESSOR] = EventSourceType.PROCESSOR
|
grasp_agents/typing/io.py
CHANGED
@@ -1,25 +1,16 @@
|
|
1
|
-
from collections.abc import Mapping
|
2
1
|
from typing import TypeAlias, TypeVar
|
3
2
|
|
4
3
|
from pydantic import BaseModel
|
5
4
|
|
6
|
-
|
7
|
-
|
8
|
-
AgentID: TypeAlias = str
|
9
|
-
|
10
|
-
|
11
|
-
class AgentState(BaseModel):
|
12
|
-
pass
|
5
|
+
ProcName: TypeAlias = str
|
13
6
|
|
14
7
|
|
15
8
|
class LLMPromptArgs(BaseModel):
|
16
9
|
pass
|
17
10
|
|
18
11
|
|
19
|
-
|
20
|
-
|
21
|
-
|
12
|
+
InT_contra = TypeVar("InT_contra", contravariant=True)
|
13
|
+
OutT_co = TypeVar("OutT_co", covariant=True)
|
14
|
+
MemT_co = TypeVar("MemT_co", covariant=True)
|
22
15
|
|
23
16
|
LLMPrompt: TypeAlias = str
|
24
|
-
LLMFormattedSystemArgs: TypeAlias = Mapping[str, str | int | bool]
|
25
|
-
LLMFormattedArgs: TypeAlias = Mapping[str, str | int | bool | ImageData]
|
grasp_agents/typing/message.py
CHANGED
@@ -4,7 +4,7 @@ from enum import StrEnum
|
|
4
4
|
from typing import Annotated, Any, Literal, TypeAlias
|
5
5
|
from uuid import uuid4
|
6
6
|
|
7
|
-
from pydantic import BaseModel, Field
|
7
|
+
from pydantic import BaseModel, Field
|
8
8
|
from pydantic.json import pydantic_encoder
|
9
9
|
|
10
10
|
from .content import Content, ImageData
|
@@ -18,85 +18,48 @@ class Role(StrEnum):
|
|
18
18
|
TOOL = "tool"
|
19
19
|
|
20
20
|
|
21
|
-
class Usage(BaseModel):
|
22
|
-
input_tokens: NonNegativeInt = 0
|
23
|
-
output_tokens: NonNegativeInt = 0
|
24
|
-
reasoning_tokens: NonNegativeInt | None = None
|
25
|
-
cached_tokens: NonNegativeInt | None = None
|
26
|
-
cost: NonNegativeFloat | None = None
|
27
|
-
|
28
|
-
def __add__(self, add_usage: "Usage") -> "Usage":
|
29
|
-
input_tokens = self.input_tokens + add_usage.input_tokens
|
30
|
-
output_tokens = self.output_tokens + add_usage.output_tokens
|
31
|
-
if self.reasoning_tokens is not None or add_usage.reasoning_tokens is not None:
|
32
|
-
reasoning_tokens = (self.reasoning_tokens or 0) + (
|
33
|
-
add_usage.reasoning_tokens or 0
|
34
|
-
)
|
35
|
-
else:
|
36
|
-
reasoning_tokens = None
|
37
|
-
|
38
|
-
if self.cached_tokens is not None or add_usage.cached_tokens is not None:
|
39
|
-
cached_tokens = (self.cached_tokens or 0) + (add_usage.cached_tokens or 0)
|
40
|
-
else:
|
41
|
-
cached_tokens = None
|
42
|
-
|
43
|
-
cost = (
|
44
|
-
(self.cost or 0.0) + add_usage.cost
|
45
|
-
if (add_usage.cost is not None)
|
46
|
-
else None
|
47
|
-
)
|
48
|
-
return Usage(
|
49
|
-
input_tokens=input_tokens,
|
50
|
-
output_tokens=output_tokens,
|
51
|
-
reasoning_tokens=reasoning_tokens,
|
52
|
-
cached_tokens=cached_tokens,
|
53
|
-
cost=cost,
|
54
|
-
)
|
55
|
-
|
56
|
-
|
57
21
|
class MessageBase(BaseModel):
|
58
|
-
|
59
|
-
|
22
|
+
id: Hashable = Field(default_factory=lambda: str(uuid4())[:8])
|
23
|
+
name: str | None = None
|
60
24
|
|
61
25
|
|
62
26
|
class AssistantMessage(MessageBase):
|
63
27
|
role: Literal[Role.ASSISTANT] = Role.ASSISTANT
|
64
28
|
content: str | None
|
65
|
-
usage: Usage | None = None
|
66
29
|
tool_calls: Sequence[ToolCall] | None = None
|
67
30
|
refusal: str | None = None
|
68
31
|
|
69
32
|
|
70
33
|
class UserMessage(MessageBase):
|
71
34
|
role: Literal[Role.USER] = Role.USER
|
72
|
-
content: Content
|
35
|
+
content: Content | str
|
73
36
|
|
74
37
|
@classmethod
|
75
|
-
def from_text(cls, text: str,
|
76
|
-
return cls(content=Content.from_text(text),
|
38
|
+
def from_text(cls, text: str, name: str | None = None) -> "UserMessage":
|
39
|
+
return cls(content=Content.from_text(text), name=name)
|
77
40
|
|
78
41
|
@classmethod
|
79
42
|
def from_formatted_prompt(
|
80
43
|
cls,
|
81
44
|
prompt_template: str,
|
82
45
|
prompt_args: Mapping[str, str | int | bool | ImageData] | None = None,
|
83
|
-
|
46
|
+
name: str | None = None,
|
84
47
|
) -> "UserMessage":
|
85
48
|
content = Content.from_formatted_prompt(
|
86
49
|
prompt_template=prompt_template, prompt_args=prompt_args
|
87
50
|
)
|
88
51
|
|
89
|
-
return cls(content=content,
|
52
|
+
return cls(content=content, name=name)
|
90
53
|
|
91
54
|
@classmethod
|
92
55
|
def from_content_parts(
|
93
56
|
cls,
|
94
57
|
content_parts: Sequence[str | ImageData],
|
95
|
-
|
58
|
+
name: str | None = None,
|
96
59
|
) -> "UserMessage":
|
97
60
|
content = Content.from_content_parts(content_parts)
|
98
61
|
|
99
|
-
return cls(content=content,
|
62
|
+
return cls(content=content, name=name)
|
100
63
|
|
101
64
|
|
102
65
|
class SystemMessage(MessageBase):
|
@@ -114,13 +77,12 @@ class ToolMessage(MessageBase):
|
|
114
77
|
cls,
|
115
78
|
tool_output: Any,
|
116
79
|
tool_call: ToolCall,
|
117
|
-
model_id: str | None = None,
|
118
80
|
indent: int = 2,
|
119
81
|
) -> "ToolMessage":
|
120
82
|
return cls(
|
121
83
|
content=json.dumps(tool_output, default=pydantic_encoder, indent=indent),
|
122
84
|
tool_call_id=tool_call.id,
|
123
|
-
|
85
|
+
name=tool_call.tool_name,
|
124
86
|
)
|
125
87
|
|
126
88
|
|
@@ -129,4 +91,4 @@ Message = Annotated[
|
|
129
91
|
Field(discriminator="role"),
|
130
92
|
]
|
131
93
|
|
132
|
-
|
94
|
+
Messages: TypeAlias = list[Message]
|
grasp_agents/typing/tool.py
CHANGED
@@ -1,23 +1,31 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
3
|
from abc import ABC, abstractmethod
|
4
|
-
from typing import
|
4
|
+
from typing import (
|
5
|
+
TYPE_CHECKING,
|
6
|
+
Any,
|
7
|
+
ClassVar,
|
8
|
+
Generic,
|
9
|
+
Literal,
|
10
|
+
TypeAlias,
|
11
|
+
TypeVar,
|
12
|
+
)
|
5
13
|
|
6
14
|
from pydantic import BaseModel, PrivateAttr, TypeAdapter
|
7
15
|
|
8
16
|
from ..generics_utils import AutoInstanceAttributesMixin
|
9
17
|
|
10
18
|
if TYPE_CHECKING:
|
11
|
-
from ..run_context import CtxT,
|
19
|
+
from ..run_context import CtxT, RunContext
|
12
20
|
else:
|
13
21
|
CtxT = TypeVar("CtxT")
|
14
22
|
|
15
|
-
class
|
16
|
-
"""Runtime placeholder so
|
23
|
+
class RunContext(Generic[CtxT]):
|
24
|
+
"""Runtime placeholder so RunContext[CtxT] works"""
|
17
25
|
|
18
26
|
|
19
|
-
|
20
|
-
|
27
|
+
_InT_contra = TypeVar("_InT_contra", bound=BaseModel, contravariant=True)
|
28
|
+
_OutT_co = TypeVar("_OutT_co", covariant=True)
|
21
29
|
|
22
30
|
|
23
31
|
class ToolCall(BaseModel):
|
@@ -27,45 +35,63 @@ class ToolCall(BaseModel):
|
|
27
35
|
|
28
36
|
|
29
37
|
class BaseTool(
|
30
|
-
AutoInstanceAttributesMixin,
|
38
|
+
AutoInstanceAttributesMixin,
|
39
|
+
BaseModel,
|
40
|
+
ABC,
|
41
|
+
Generic[_InT_contra, _OutT_co, CtxT],
|
31
42
|
):
|
32
43
|
_generic_arg_to_instance_attr_map: ClassVar[dict[int, str]] = {
|
33
|
-
0: "
|
34
|
-
1: "
|
44
|
+
0: "_in_type",
|
45
|
+
1: "_out_type",
|
35
46
|
}
|
36
47
|
|
37
48
|
name: str
|
38
49
|
description: str
|
39
50
|
|
40
|
-
|
41
|
-
|
51
|
+
_in_type: type[_InT_contra] = PrivateAttr()
|
52
|
+
_out_type: type[_OutT_co] = PrivateAttr()
|
53
|
+
|
54
|
+
# _in_type_adapter: TypeAdapter[_InT_contra] = PrivateAttr()
|
55
|
+
# _out_type_adapter: TypeAdapter[_OutT_co] = PrivateAttr()
|
42
56
|
|
43
|
-
#
|
44
|
-
|
57
|
+
# def model_post_init(self, context: Any) -> None:
|
58
|
+
# self._in_type_adapter = TypeAdapter(self._in_type)
|
59
|
+
# self._out_type_adapter = TypeAdapter(self._out_type)
|
45
60
|
|
46
61
|
@property
|
47
|
-
def
|
62
|
+
def in_type(self) -> type[_InT_contra]: # type: ignore[reportInvalidTypeVarUse]
|
48
63
|
# Exposing the type of a contravariant variable only, should be type safe
|
49
|
-
return self.
|
64
|
+
return self._in_type
|
50
65
|
|
51
66
|
@property
|
52
|
-
def
|
53
|
-
return self.
|
67
|
+
def out_type(self) -> type[_OutT_co]:
|
68
|
+
return self._out_type
|
69
|
+
|
70
|
+
# @property
|
71
|
+
# def in_type_adapter(self) -> TypeAdapter[_InT_contra]:
|
72
|
+
# return self._in_type_adapter
|
73
|
+
|
74
|
+
# @property
|
75
|
+
# def out_type_adapter(self) -> TypeAdapter[_OutT_co]:
|
76
|
+
# return self._out_type_adapter
|
54
77
|
|
55
78
|
@abstractmethod
|
56
79
|
async def run(
|
57
|
-
self, inp:
|
58
|
-
) ->
|
80
|
+
self, inp: _InT_contra, ctx: RunContext[CtxT] | None = None
|
81
|
+
) -> _OutT_co:
|
59
82
|
pass
|
60
83
|
|
61
84
|
async def __call__(
|
62
|
-
self, ctx:
|
63
|
-
) ->
|
64
|
-
|
85
|
+
self, ctx: RunContext[CtxT] | None = None, **kwargs: Any
|
86
|
+
) -> _OutT_co:
|
87
|
+
input_args = TypeAdapter(self._in_type).validate_python(kwargs)
|
88
|
+
output = await self.run(input_args, ctx=ctx)
|
65
89
|
|
66
|
-
return TypeAdapter(self.
|
90
|
+
return TypeAdapter(self._out_type).validate_python(output)
|
67
91
|
|
68
92
|
|
69
|
-
|
70
|
-
|
71
|
-
|
93
|
+
class NamedToolChoice(BaseModel):
|
94
|
+
name: str
|
95
|
+
|
96
|
+
|
97
|
+
ToolChoice: TypeAlias = Literal["none", "auto", "required"] | NamedToolChoice
|
grasp_agents/usage_tracker.py
CHANGED
@@ -7,7 +7,7 @@ import yaml
|
|
7
7
|
from pydantic import BaseModel, Field
|
8
8
|
from termcolor import colored
|
9
9
|
|
10
|
-
from .typing.
|
10
|
+
from .typing.completion import Completion, Usage
|
11
11
|
|
12
12
|
logger = logging.getLogger(__name__)
|
13
13
|
|
@@ -58,20 +58,20 @@ class UsageTracker(BaseModel):
|
|
58
58
|
usage.cost = (input_cost + output_cost + reasoning_cost + cached_cost) / 1e6
|
59
59
|
|
60
60
|
def update(
|
61
|
-
self,
|
61
|
+
self, completions: Sequence[Completion], model_name: str | None = None
|
62
62
|
) -> None:
|
63
63
|
if model_name is not None and self.costs_dict is not None:
|
64
64
|
model_costs_dict = self.costs_dict.get(model_name.split(":", 1)[-1])
|
65
65
|
else:
|
66
66
|
model_costs_dict = None
|
67
67
|
|
68
|
-
for
|
69
|
-
if
|
68
|
+
for completion in completions:
|
69
|
+
if completion.usage is not None:
|
70
70
|
if model_costs_dict is not None:
|
71
71
|
self._add_cost_to_usage(
|
72
|
-
usage=
|
72
|
+
usage=completion.usage, model_costs_dict=model_costs_dict
|
73
73
|
)
|
74
|
-
self.total_usage +=
|
74
|
+
self.total_usage += completion.usage
|
75
75
|
|
76
76
|
def reset(self) -> None:
|
77
77
|
self.total_usage = Usage()
|
grasp_agents/utils.py
CHANGED
@@ -3,7 +3,7 @@ import asyncio
|
|
3
3
|
import json
|
4
4
|
import re
|
5
5
|
from collections.abc import Coroutine, Mapping
|
6
|
-
from datetime import datetime
|
6
|
+
from datetime import UTC, datetime
|
7
7
|
from logging import getLogger
|
8
8
|
from pathlib import Path
|
9
9
|
from typing import Any, TypeVar, overload
|
@@ -126,7 +126,7 @@ def read_contents_from_file(
|
|
126
126
|
return Path(file_path).read_bytes()
|
127
127
|
return Path(file_path).read_text()
|
128
128
|
except FileNotFoundError:
|
129
|
-
logger.
|
129
|
+
logger.exception(f"File {file_path} not found.")
|
130
130
|
return ""
|
131
131
|
|
132
132
|
|
@@ -157,4 +157,4 @@ async def asyncio_gather_with_pbar(
|
|
157
157
|
|
158
158
|
|
159
159
|
def get_timestamp() -> str:
|
160
|
-
return datetime.now().strftime("%Y%m%d_%H%M%S")
|
160
|
+
return datetime.now(UTC).strftime("%Y%m%d_%H%M%S")
|