grasp_agents 0.4.6__py3-none-any.whl → 0.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- grasp_agents/cloud_llm.py +191 -218
- grasp_agents/comm_processor.py +101 -100
- grasp_agents/errors.py +69 -9
- grasp_agents/litellm/__init__.py +106 -0
- grasp_agents/litellm/completion_chunk_converters.py +68 -0
- grasp_agents/litellm/completion_converters.py +72 -0
- grasp_agents/litellm/converters.py +138 -0
- grasp_agents/litellm/lite_llm.py +210 -0
- grasp_agents/litellm/message_converters.py +66 -0
- grasp_agents/llm.py +84 -49
- grasp_agents/llm_agent.py +136 -120
- grasp_agents/llm_agent_memory.py +3 -3
- grasp_agents/llm_policy_executor.py +167 -174
- grasp_agents/memory.py +4 -0
- grasp_agents/openai/__init__.py +24 -9
- grasp_agents/openai/completion_chunk_converters.py +6 -6
- grasp_agents/openai/completion_converters.py +12 -14
- grasp_agents/openai/content_converters.py +1 -3
- grasp_agents/openai/converters.py +6 -8
- grasp_agents/openai/message_converters.py +21 -3
- grasp_agents/openai/openai_llm.py +155 -103
- grasp_agents/openai/tool_converters.py +4 -6
- grasp_agents/packet.py +5 -2
- grasp_agents/packet_pool.py +14 -13
- grasp_agents/printer.py +234 -72
- grasp_agents/processor.py +228 -88
- grasp_agents/prompt_builder.py +2 -2
- grasp_agents/run_context.py +11 -20
- grasp_agents/runner.py +42 -0
- grasp_agents/typing/completion.py +16 -9
- grasp_agents/typing/completion_chunk.py +51 -22
- grasp_agents/typing/events.py +95 -19
- grasp_agents/typing/message.py +25 -1
- grasp_agents/typing/tool.py +2 -0
- grasp_agents/usage_tracker.py +31 -37
- grasp_agents/utils.py +95 -84
- grasp_agents/workflow/looped_workflow.py +60 -11
- grasp_agents/workflow/sequential_workflow.py +43 -11
- grasp_agents/workflow/workflow_processor.py +25 -24
- {grasp_agents-0.4.6.dist-info → grasp_agents-0.5.0.dist-info}/METADATA +7 -6
- grasp_agents-0.5.0.dist-info/RECORD +57 -0
- grasp_agents-0.4.6.dist-info/RECORD +0 -50
- {grasp_agents-0.4.6.dist-info → grasp_agents-0.5.0.dist-info}/WHEEL +0 -0
- {grasp_agents-0.4.6.dist-info → grasp_agents-0.5.0.dist-info}/licenses/LICENSE.md +0 -0
grasp_agents/typing/completion_chunk.py
CHANGED
@@ -1,25 +1,31 @@
 import time
 from collections import defaultdict
 from collections.abc import Sequence
+from typing import Any
 from uuid import uuid4
 
+from litellm import ChatCompletionAnnotation as LiteLLMAnnotation
+from litellm.types.utils import ChoiceLogprobs as LiteLLMChoiceLogprobs
+from openai.types.chat.chat_completion import (
+    ChoiceLogprobs as OpenAIChoiceLogprobs,
+)
 from openai.types.chat.chat_completion_chunk import (
-    ChoiceLogprobs as
+    ChoiceLogprobs as OpenAIChunkChoiceLogprobs,
 )
 from openai.types.chat.chat_completion_token_logprob import (
-    ChatCompletionTokenLogprob as
+    ChatCompletionTokenLogprob as OpenAITokenLogprob,
 )
 from pydantic import BaseModel, Field
 
 from ..errors import CombineCompletionChunksError
-from .completion import
-
-
-
-
-
+from .completion import Completion, CompletionChoice, FinishReason, Usage
+from .message import (
+    AssistantMessage,
+    RedactedThinkingBlock,
+    Role,
+    ThinkingBlock,
+    ToolCall,
 )
-from .message import AssistantMessage, ToolCall
 
 
 class CompletionChunkDeltaToolCall(BaseModel):
@@ -31,26 +37,34 @@ class CompletionChunkDeltaToolCall(BaseModel):
 
 class CompletionChunkChoiceDelta(BaseModel):
     content: str | None = None
-    refusal: str | None
-    role:
+    refusal: str | None = None
+    role: Role | None
     tool_calls: list[CompletionChunkDeltaToolCall] | None
+    reasoning_content: str | None = None
+    thinking_blocks: list[ThinkingBlock | RedactedThinkingBlock] | None = None
+    annotations: list[LiteLLMAnnotation] | None = None
+    provider_specific_fields: dict[str, Any] | None = None
 
 
 class CompletionChunkChoice(BaseModel):
     delta: CompletionChunkChoiceDelta
     finish_reason: FinishReason | None
     index: int
-    logprobs:
+    logprobs: OpenAIChunkChoiceLogprobs | LiteLLMChoiceLogprobs | Any | None = None
 
 
 class CompletionChunk(BaseModel):
     id: str = Field(default_factory=lambda: str(uuid4())[:8])
     created: int = Field(default_factory=lambda: int(time.time()))
-    model: str
+    model: str | None
     name: str | None = None
     system_fingerprint: str | None = None
     choices: list[CompletionChunkChoice]
     usage: Usage | None = None
+    # LiteLLM-specific fields
+    provider_specific_fields: dict[str, Any] | None = None
+    response_ms: float | None = None
+    hidden_params: dict[str, Any] | None = None
 
 
 def combine_completion_chunks(chunks: list[CompletionChunk]) -> Completion:
@@ -82,14 +96,20 @@ def combine_completion_chunks(chunks: list[CompletionChunk]) -> Completion:
     # Usage is found in the last completion chunk if requested
     usage = chunks[-1].usage
 
-    logp_contents_per_choice: defaultdict[int, list[
-
+    logp_contents_per_choice: defaultdict[int, list[OpenAITokenLogprob]] = defaultdict(
+        list
+    )
+    logp_refusals_per_choice: defaultdict[int, list[OpenAITokenLogprob]] = defaultdict(
+        list
     )
-
-
+    logprobs_per_choice: defaultdict[int, OpenAIChoiceLogprobs | None] = defaultdict(
+        lambda: None
     )
-
-
+    thinking_blocks_per_choice: defaultdict[
+        int, list[ThinkingBlock | RedactedThinkingBlock]
+    ] = defaultdict(list)
+    annotations_per_choice: defaultdict[int, list[LiteLLMAnnotation]] = defaultdict(
+        list
    )
 
     finish_reasons_per_choice: defaultdict[int, FinishReason | None] = defaultdict(
@@ -97,6 +117,7 @@ def combine_completion_chunks(chunks: list[CompletionChunk]) -> Completion:
     )
 
     contents_per_choice: defaultdict[int, str] = defaultdict(lambda: "")
+    reasoning_contents_per_choice: defaultdict[int, str] = defaultdict(lambda: "")
     refusals_per_choice: defaultdict[int, str] = defaultdict(lambda: "")
 
     tool_calls_per_choice: defaultdict[
@@ -111,12 +132,17 @@ def combine_completion_chunks(chunks: list[CompletionChunk]) -> Completion:
 
         # Concatenate content and refusal tokens for each choice
         contents_per_choice[index] += choice.delta.content or ""
+        reasoning_contents_per_choice[index] += choice.delta.reasoning_content or ""
         refusals_per_choice[index] += choice.delta.refusal or ""
 
         # Concatenate logprobs for content and refusal tokens for each choice
         if choice.logprobs is not None:
-            logp_contents_per_choice[index].extend(choice.logprobs.content or [])
-            logp_refusals_per_choice[index].extend(choice.logprobs.refusal or [])
+            logp_contents_per_choice[index].extend(choice.logprobs.content or [])  # type: ignore
+            logp_refusals_per_choice[index].extend(choice.logprobs.refusal or [])  # type: ignore
+        thinking_blocks_per_choice[index].extend(
+            choice.delta.thinking_blocks or []
+        )
+        annotations_per_choice[index].extend(choice.delta.annotations or [])
 
         # Take the last finish reason for each choice
         finish_reasons_per_choice[index] = choice.finish_reason
@@ -148,12 +174,15 @@ def combine_completion_chunks(chunks: list[CompletionChunk]) -> Completion:
         messages_per_choice[index] = AssistantMessage(
             name=name,
             content=contents_per_choice[index] or "<empty>",
+            reasoning_content=(reasoning_contents_per_choice[index] or None),
+            thinking_blocks=(thinking_blocks_per_choice[index] or None),
+            annotations=(annotations_per_choice[index] or None),
             refusal=(refusals_per_choice[index] or None),
             tool_calls=(tool_calls or None),
         )
 
         if logp_contents_per_choice[index] or logp_refusals_per_choice[index]:
-            logprobs_per_choice[index] =
+            logprobs_per_choice[index] = OpenAIChoiceLogprobs(
                 content=logp_contents_per_choice[index],
                 refusal=logp_refusals_per_choice[index],
             )
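The streaming model now carries reasoning deltas end to end: `CompletionChunkChoiceDelta` gains `reasoning_content`, `thinking_blocks`, and `annotations`, and `combine_completion_chunks` accumulates them per choice index into the final `AssistantMessage`. A minimal sketch of that round trip; the `delta_chunk` helper is hypothetical, and it assumes `FinishReason` admits the literal "stop" and that `CompletionChoice` exposes a `message` field:

```python
from grasp_agents.typing.completion_chunk import (
    CompletionChunk,
    CompletionChunkChoice,
    CompletionChunkChoiceDelta,
    combine_completion_chunks,
)


def delta_chunk(text: str, reasoning: str | None, finish: str | None) -> CompletionChunk:
    # Hypothetical helper: wraps a single-choice delta. A shared id and model
    # are used in case the combiner checks that chunks belong to one stream.
    return CompletionChunk(
        id="stream-1",
        model="some-model",
        choices=[
            CompletionChunkChoice(
                delta=CompletionChunkChoiceDelta(
                    content=text,
                    role=None,
                    tool_calls=None,
                    reasoning_content=reasoning,
                ),
                finish_reason=finish,  # assuming FinishReason admits "stop"
                index=0,
            )
        ],
    )


chunks = [delta_chunk("Hel", "step 1...", None), delta_chunk("lo", None, "stop")]
message = combine_completion_chunks(chunks).choices[0].message
# message.content == "Hello"; message.reasoning_content == "step 1..."
```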
grasp_agents/typing/events.py
CHANGED
@@ -1,8 +1,9 @@
 import time
 from enum import StrEnum
 from typing import Any, Generic, Literal, TypeVar
+from uuid import uuid4
 
-from pydantic import BaseModel, Field
+from pydantic import BaseModel, ConfigDict, Field
 
 from ..packet import Packet
 from .completion import Completion
@@ -15,7 +16,9 @@ class EventSourceType(StrEnum):
     AGENT = "agent"
     USER = "user"
     TOOL = "tool"
-
+    PROC = "processor"
+    WORKFLOW = "workflow"
+    RUN = "run"
 
 
 class EventType(StrEnum):
@@ -24,10 +27,22 @@ class EventType(StrEnum):
     TOOL_MSG = "tool_message"
     TOOL_CALL = "tool_call"
     GEN_MSG = "gen_message"
+
     COMP = "completion"
     COMP_CHUNK = "completion_chunk"
-
-
+    LLM_ERR = "llm_error"
+
+    PROC_START = "processor_start"
+    PACKET_OUT = "packet_output"
+    PAYLOAD_OUT = "payload_output"
+    PROC_FINISH = "processor_finish"
+    PROC_ERR = "processor_error"
+
+    WORKFLOW_RES = "workflow_result"
+    RUN_RES = "run_result"
+
+    # COMP_THINK_CHUNK = "completion_thinking_chunk"
+    # COMP_RESP_CHUNK = "completion_response_chunk"
 
 
 _T = TypeVar("_T")
@@ -36,8 +51,10 @@ _T = TypeVar("_T")
 class Event(BaseModel, Generic[_T], frozen=True):
     type: EventType
     source: EventSourceType
+    id: str = Field(default_factory=lambda: str(uuid4()))
     created: int = Field(default_factory=lambda: int(time.time()))
-
+    proc_name: str | None = None
+    call_id: str | None = None
     data: _T
 
 
@@ -51,36 +68,95 @@ class CompletionChunkEvent(Event[CompletionChunk], frozen=True):
     source: Literal[EventSourceType.LLM] = EventSourceType.LLM
 
 
-class
-
+class LLMStreamingErrorData(BaseModel):
+    error: Exception
+    model_name: str | None = None
+    model_id: str | None = None
+
+    model_config = ConfigDict(arbitrary_types_allowed=True)
+
+
+class LLMStreamingErrorEvent(Event[LLMStreamingErrorData], frozen=True):
+    type: Literal[EventType.LLM_ERR] = EventType.LLM_ERR
     source: Literal[EventSourceType.LLM] = EventSourceType.LLM
 
 
-class
-
-
+# class CompletionThinkingChunkEvent(Event[CompletionChunk], frozen=True):
+#     type: Literal[EventType.COMP_THINK_CHUNK] = EventType.COMP_THINK_CHUNK
+#     source: Literal[EventSourceType.LLM] = EventSourceType.LLM
+
+
+# class CompletionResponseChunkEvent(Event[CompletionChunk], frozen=True):
+#     type: Literal[EventType.COMP_RESP_CHUNK] = EventType.COMP_RESP_CHUNK
+#     source: Literal[EventSourceType.LLM] = EventSourceType.LLM
+
+
+class MessageEvent(Event[_T], Generic[_T], frozen=True):
+    pass
+
+
+class GenMessageEvent(MessageEvent[AssistantMessage], frozen=True):
+    type: Literal[EventType.GEN_MSG] = EventType.GEN_MSG
+    source: Literal[EventSourceType.LLM] = EventSourceType.LLM
 
 
-class ToolMessageEvent(
+class ToolMessageEvent(MessageEvent[ToolMessage], frozen=True):
     type: Literal[EventType.TOOL_MSG] = EventType.TOOL_MSG
     source: Literal[EventSourceType.TOOL] = EventSourceType.TOOL
 
 
-class UserMessageEvent(
+class UserMessageEvent(MessageEvent[UserMessage], frozen=True):
     type: Literal[EventType.USR_MSG] = EventType.USR_MSG
     source: Literal[EventSourceType.USER] = EventSourceType.USER
 
 
-class SystemMessageEvent(
+class SystemMessageEvent(MessageEvent[SystemMessage], frozen=True):
     type: Literal[EventType.SYS_MSG] = EventType.SYS_MSG
     source: Literal[EventSourceType.AGENT] = EventSourceType.AGENT
 
 
-class
-    type: Literal[EventType.
-    source: Literal[EventSourceType.
+class ToolCallEvent(Event[ToolCall], frozen=True):
+    type: Literal[EventType.TOOL_CALL] = EventType.TOOL_CALL
+    source: Literal[EventSourceType.AGENT] = EventSourceType.AGENT
+
+
+class ProcStartEvent(Event[None], frozen=True):
+    type: Literal[EventType.PROC_START] = EventType.PROC_START
+    source: Literal[EventSourceType.PROC] = EventSourceType.PROC
+
+
+class ProcFinishEvent(Event[None], frozen=True):
+    type: Literal[EventType.PROC_FINISH] = EventType.PROC_FINISH
+    source: Literal[EventSourceType.PROC] = EventSourceType.PROC
+
+
+class ProcPayloadOutputEvent(Event[Any], frozen=True):
+    type: Literal[EventType.PAYLOAD_OUT] = EventType.PAYLOAD_OUT
+    source: Literal[EventSourceType.PROC] = EventSourceType.PROC
+
+
+class ProcPacketOutputEvent(Event[Packet[Any]], frozen=True):
+    type: Literal[EventType.PACKET_OUT] = EventType.PACKET_OUT
+    source: Literal[EventSourceType.PROC] = EventSourceType.PROC
+
+
+class WorkflowResultEvent(Event[Packet[Any]], frozen=True):
+    type: Literal[EventType.WORKFLOW_RES] = EventType.WORKFLOW_RES
+    source: Literal[EventSourceType.WORKFLOW] = EventSourceType.WORKFLOW
+
+
+class RunResultEvent(Event[Packet[Any]], frozen=True):
+    type: Literal[EventType.RUN_RES] = EventType.RUN_RES
+    source: Literal[EventSourceType.RUN] = EventSourceType.RUN
+
+
+class ProcStreamingErrorData(BaseModel):
+    error: Exception
+    call_id: str | None = None
+
+    model_config = ConfigDict(arbitrary_types_allowed=True)
 
 
-class
-    type: Literal[EventType.
-    source: Literal[EventSourceType.
+class ProcStreamingErrorEvent(Event[ProcStreamingErrorData], frozen=True):
+    type: Literal[EventType.PROC_ERR] = EventType.PROC_ERR
+    source: Literal[EventSourceType.PROC] = EventSourceType.PROC
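Every event now carries an `id`, a `proc_name`, and a `call_id`, and the new processor/workflow/run event types make a streaming consumer a plain `isinstance` dispatch. A hedged sketch, assuming `stream` is an async iterator of `Event` objects produced by a streaming run (the producer API is not shown in this diff):

```python
from grasp_agents.typing.events import (
    CompletionChunkEvent,
    ProcPacketOutputEvent,
    ProcStreamingErrorEvent,
)


async def consume(stream) -> None:
    # The new id/proc_name/call_id fields identify which processor
    # run emitted each event in a multi-processor stream.
    async for event in stream:
        if isinstance(event, CompletionChunkEvent):
            print(event.data.choices[0].delta.content or "", end="")
        elif isinstance(event, ProcPacketOutputEvent):
            print(f"\n[{event.proc_name}] produced packet {event.data}")
        elif isinstance(event, ProcStreamingErrorEvent):
            # Errors are surfaced as events instead of tearing down the stream.
            raise event.data.error
```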
grasp_agents/typing/message.py
CHANGED
@@ -1,11 +1,13 @@
 import json
 from collections.abc import Hashable, Mapping, Sequence
 from enum import StrEnum
-from typing import Annotated, Any, Literal, TypeAlias
+from typing import Annotated, Any, Literal, Required, TypeAlias
 from uuid import uuid4
 
+from litellm.types.llms.openai import ChatCompletionAnnotation as LiteLLMAnnotation
 from pydantic import BaseModel, Field
 from pydantic.json import pydantic_encoder
+from typing_extensions import TypedDict
 
 from .content import Content, ImageData
 from .tool import ToolCall
@@ -14,6 +16,7 @@ from .tool import ToolCall
 class Role(StrEnum):
     USER = "user"
     SYSTEM = "system"
+    DEVELOPER = "developer"
     ASSISTANT = "assistant"
     TOOL = "tool"
 
@@ -23,11 +26,32 @@ class MessageBase(BaseModel):
     name: str | None = None
 
 
+class ChatCompletionCachedContent(TypedDict):
+    type: Literal["ephemeral"]
+
+
+class ThinkingBlock(TypedDict, total=False):
+    type: Required[Literal["thinking"]]
+    thinking: str
+    signature: str | None
+    cache_control: dict[str, Any] | ChatCompletionCachedContent | None
+
+
+class RedactedThinkingBlock(TypedDict, total=False):
+    type: Required[Literal["redacted_thinking"]]
+    data: str
+    cache_control: dict[str, Any] | ChatCompletionCachedContent | None
+
+
 class AssistantMessage(MessageBase):
     role: Literal[Role.ASSISTANT] = Role.ASSISTANT
     content: str | None
     tool_calls: Sequence[ToolCall] | None = None
     refusal: str | None = None
+    reasoning_content: str | None = None
+    thinking_blocks: Sequence[ThinkingBlock | RedactedThinkingBlock] | None = None
+    annotations: Sequence[LiteLLMAnnotation] | None = None
+    provider_specific_fields: dict[str, Any] | None = None
 
 
 class UserMessage(MessageBase):
grasp_agents/typing/tool.py
CHANGED
grasp_agents/usage_tracker.py
CHANGED
@@ -20,7 +20,6 @@ CostsDict: TypeAlias = dict[str, ModelCostsDict]
 
 
 class UsageTracker(BaseModel):
-    # TODO: specify different costs per provider:model, not just per model
     costs_dict_path: str | Path = COSTS_DICT_PATH
     costs_dict: CostsDict | None = None
     usages: dict[str, Usage] = Field(default_factory=dict)
@@ -29,34 +28,6 @@ class UsageTracker(BaseModel):
         super().__init__(**kwargs)
         self.costs_dict = self.load_costs_dict()
 
-    def load_costs_dict(self) -> CostsDict | None:
-        try:
-            with Path(self.costs_dict_path).open() as f:
-                return yaml.safe_load(f)["costs"]
-        except Exception:
-            logger.info(f"Failed to load cost dictionary from {self.costs_dict_path}")
-            return None
-
-    def _add_cost_to_usage(
-        self, usage: Usage, model_costs_dict: ModelCostsDict
-    ) -> None:
-        in_rate = model_costs_dict["input"]
-        out_rate = model_costs_dict["output"]
-        cached_discount = model_costs_dict.get("cached_discount")
-        input_cost = in_rate * usage.input_tokens
-        output_cost = out_rate * usage.output_tokens
-        reasoning_cost = (
-            out_rate * usage.reasoning_tokens
-            if usage.reasoning_tokens is not None
-            else 0.0
-        )
-        cached_cost: float = (
-            cached_discount * in_rate * usage.cached_tokens
-            if (usage.cached_tokens is not None) and (cached_discount is not None)
-            else 0.0
-        )
-        usage.cost = (input_cost + output_cost + reasoning_cost + cached_cost) / 1e6
-
     def update(
         self,
         agent_name: str,
@@ -64,13 +35,13 @@ class UsageTracker(BaseModel):
         model_name: str | None = None,
     ) -> None:
         if model_name is not None and self.costs_dict is not None:
-            model_costs_dict = self.costs_dict.get(model_name.split("
+            model_costs_dict = self.costs_dict.get(model_name.split("/", 1)[-1])
         else:
             model_costs_dict = None
 
         for completion in completions:
             if completion.usage is not None:
-                if model_costs_dict is not None:
+                if completion.usage.cost is None and model_costs_dict is not None:
                     self._add_cost_to_usage(
                         usage=completion.usage, model_costs_dict=model_costs_dict
                     )
@@ -100,9 +71,32 @@ class UsageTracker(BaseModel):
             logger.debug(colored(token_usage_str, "light_grey"))
 
         if usage.cost is not None:
-            logger.debug(
-
-
-
-
+            logger.debug(colored(f"Total cost: ${usage.cost:.4f}", "light_grey"))
+
+    def load_costs_dict(self) -> CostsDict | None:
+        try:
+            with Path(self.costs_dict_path).open() as f:
+                return yaml.safe_load(f)["costs"]
+        except Exception:
+            logger.info(f"Failed to load cost dictionary from {self.costs_dict_path}")
+            return None
+
+    def _add_cost_to_usage(
+        self, usage: Usage, model_costs_dict: ModelCostsDict
+    ) -> None:
+        in_rate = model_costs_dict["input"]
+        out_rate = model_costs_dict["output"]
+        cached_discount = model_costs_dict.get("cached_discount")
+        input_cost = in_rate * usage.input_tokens
+        output_cost = out_rate * usage.output_tokens
+        reasoning_cost = (
+            out_rate * usage.reasoning_tokens
+            if usage.reasoning_tokens is not None
+            else 0.0
+        )
+        cached_cost: float = (
+            cached_discount * in_rate * usage.cached_tokens
+            if (usage.cached_tokens is not None) and (cached_discount is not None)
+            else 0.0
+        )
+        usage.cost = (input_cost + output_cost + reasoning_cost + cached_cost) / 1e6