grasp_agents 0.4.7__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44)
  1. grasp_agents/cloud_llm.py +191 -224
  2. grasp_agents/comm_processor.py +101 -100
  3. grasp_agents/errors.py +69 -9
  4. grasp_agents/litellm/__init__.py +106 -0
  5. grasp_agents/litellm/completion_chunk_converters.py +68 -0
  6. grasp_agents/litellm/completion_converters.py +72 -0
  7. grasp_agents/litellm/converters.py +138 -0
  8. grasp_agents/litellm/lite_llm.py +210 -0
  9. grasp_agents/litellm/message_converters.py +66 -0
  10. grasp_agents/llm.py +84 -49
  11. grasp_agents/llm_agent.py +136 -120
  12. grasp_agents/llm_agent_memory.py +3 -3
  13. grasp_agents/llm_policy_executor.py +167 -174
  14. grasp_agents/memory.py +4 -0
  15. grasp_agents/openai/__init__.py +24 -9
  16. grasp_agents/openai/completion_chunk_converters.py +6 -6
  17. grasp_agents/openai/completion_converters.py +12 -14
  18. grasp_agents/openai/content_converters.py +1 -3
  19. grasp_agents/openai/converters.py +6 -8
  20. grasp_agents/openai/message_converters.py +21 -3
  21. grasp_agents/openai/openai_llm.py +155 -103
  22. grasp_agents/openai/tool_converters.py +4 -6
  23. grasp_agents/packet.py +5 -2
  24. grasp_agents/packet_pool.py +14 -13
  25. grasp_agents/printer.py +234 -72
  26. grasp_agents/processor.py +228 -88
  27. grasp_agents/prompt_builder.py +2 -2
  28. grasp_agents/run_context.py +11 -20
  29. grasp_agents/runner.py +42 -0
  30. grasp_agents/typing/completion.py +16 -9
  31. grasp_agents/typing/completion_chunk.py +51 -22
  32. grasp_agents/typing/events.py +95 -19
  33. grasp_agents/typing/message.py +25 -1
  34. grasp_agents/typing/tool.py +2 -0
  35. grasp_agents/usage_tracker.py +31 -37
  36. grasp_agents/utils.py +95 -84
  37. grasp_agents/workflow/looped_workflow.py +60 -11
  38. grasp_agents/workflow/sequential_workflow.py +43 -11
  39. grasp_agents/workflow/workflow_processor.py +25 -24
  40. {grasp_agents-0.4.7.dist-info → grasp_agents-0.5.0.dist-info}/METADATA +7 -6
  41. grasp_agents-0.5.0.dist-info/RECORD +57 -0
  42. grasp_agents-0.4.7.dist-info/RECORD +0 -50
  43. {grasp_agents-0.4.7.dist-info → grasp_agents-0.5.0.dist-info}/WHEEL +0 -0
  44. {grasp_agents-0.4.7.dist-info → grasp_agents-0.5.0.dist-info}/licenses/LICENSE.md +0 -0
grasp_agents/typing/completion_chunk.py

@@ -1,25 +1,31 @@
 import time
 from collections import defaultdict
 from collections.abc import Sequence
+from typing import Any
 from uuid import uuid4
 
+from litellm import ChatCompletionAnnotation as LiteLLMAnnotation
+from litellm.types.utils import ChoiceLogprobs as LiteLLMChoiceLogprobs
+from openai.types.chat.chat_completion import (
+    ChoiceLogprobs as OpenAIChoiceLogprobs,
+)
 from openai.types.chat.chat_completion_chunk import (
-    ChoiceLogprobs as CompletionChunkChoiceLogprobs,
+    ChoiceLogprobs as OpenAIChunkChoiceLogprobs,
 )
 from openai.types.chat.chat_completion_token_logprob import (
-    ChatCompletionTokenLogprob as CompletionTokenLogprob,
+    ChatCompletionTokenLogprob as OpenAITokenLogprob,
 )
 from pydantic import BaseModel, Field
 
 from ..errors import CombineCompletionChunksError
-from .completion import (
-    Completion,
-    CompletionChoice,
-    CompletionChoiceLogprobs,
-    FinishReason,
-    Usage,
+from .completion import Completion, CompletionChoice, FinishReason, Usage
+from .message import (
+    AssistantMessage,
+    RedactedThinkingBlock,
+    Role,
+    ThinkingBlock,
+    ToolCall,
 )
-from .message import AssistantMessage, ToolCall
 
 
 class CompletionChunkDeltaToolCall(BaseModel):
@@ -31,26 +37,34 @@ class CompletionChunkDeltaToolCall(BaseModel):
 
 class CompletionChunkChoiceDelta(BaseModel):
     content: str | None = None
-    refusal: str | None
-    role: str | None
+    refusal: str | None = None
+    role: Role | None
     tool_calls: list[CompletionChunkDeltaToolCall] | None
+    reasoning_content: str | None = None
+    thinking_blocks: list[ThinkingBlock | RedactedThinkingBlock] | None = None
+    annotations: list[LiteLLMAnnotation] | None = None
+    provider_specific_fields: dict[str, Any] | None = None
 
 
 class CompletionChunkChoice(BaseModel):
     delta: CompletionChunkChoiceDelta
     finish_reason: FinishReason | None
     index: int
-    logprobs: CompletionChunkChoiceLogprobs | None = None
+    logprobs: OpenAIChunkChoiceLogprobs | LiteLLMChoiceLogprobs | Any | None = None
 
 
 class CompletionChunk(BaseModel):
     id: str = Field(default_factory=lambda: str(uuid4())[:8])
     created: int = Field(default_factory=lambda: int(time.time()))
-    model: str
+    model: str | None
     name: str | None = None
     system_fingerprint: str | None = None
     choices: list[CompletionChunkChoice]
     usage: Usage | None = None
+    # LiteLLM-specific fields
+    provider_specific_fields: dict[str, Any] | None = None
+    response_ms: float | None = None
+    hidden_params: dict[str, Any] | None = None
 
 
 def combine_completion_chunks(chunks: list[CompletionChunk]) -> Completion:
@@ -82,14 +96,20 @@ def combine_completion_chunks(chunks: list[CompletionChunk]) -> Completion:
     # Usage is found in the last completion chunk if requested
     usage = chunks[-1].usage
 
-    logp_contents_per_choice: defaultdict[int, list[CompletionTokenLogprob]] = (
-        defaultdict(list)
+    logp_contents_per_choice: defaultdict[int, list[OpenAITokenLogprob]] = defaultdict(
+        list
+    )
+    logp_refusals_per_choice: defaultdict[int, list[OpenAITokenLogprob]] = defaultdict(
+        list
     )
-    logp_refusals_per_choice: defaultdict[int, list[CompletionTokenLogprob]] = (
-        defaultdict(list)
+    logprobs_per_choice: defaultdict[int, OpenAIChoiceLogprobs | None] = defaultdict(
+        lambda: None
     )
-    logprobs_per_choice: defaultdict[int, CompletionChoiceLogprobs | None] = (
-        defaultdict(lambda: None)
+    thinking_blocks_per_choice: defaultdict[
+        int, list[ThinkingBlock | RedactedThinkingBlock]
+    ] = defaultdict(list)
+    annotations_per_choice: defaultdict[int, list[LiteLLMAnnotation]] = defaultdict(
+        list
     )
 
     finish_reasons_per_choice: defaultdict[int, FinishReason | None] = defaultdict(
@@ -97,6 +117,7 @@ def combine_completion_chunks(chunks: list[CompletionChunk]) -> Completion:
     )
 
     contents_per_choice: defaultdict[int, str] = defaultdict(lambda: "")
+    reasoning_contents_per_choice: defaultdict[int, str] = defaultdict(lambda: "")
    refusals_per_choice: defaultdict[int, str] = defaultdict(lambda: "")
 
     tool_calls_per_choice: defaultdict[
@@ -111,12 +132,17 @@ def combine_completion_chunks(chunks: list[CompletionChunk]) -> Completion:
 
         # Concatenate content and refusal tokens for each choice
         contents_per_choice[index] += choice.delta.content or ""
+        reasoning_contents_per_choice[index] += choice.delta.reasoning_content or ""
         refusals_per_choice[index] += choice.delta.refusal or ""
 
         # Concatenate logprobs for content and refusal tokens for each choice
         if choice.logprobs is not None:
-            logp_contents_per_choice[index].extend(choice.logprobs.content or [])
-            logp_refusals_per_choice[index].extend(choice.logprobs.refusal or [])
+            logp_contents_per_choice[index].extend(choice.logprobs.content or [])  # type: ignore
+            logp_refusals_per_choice[index].extend(choice.logprobs.refusal or [])  # type: ignore
+        thinking_blocks_per_choice[index].extend(
+            choice.delta.thinking_blocks or []
+        )
+        annotations_per_choice[index].extend(choice.delta.annotations or [])
 
         # Take the last finish reason for each choice
         finish_reasons_per_choice[index] = choice.finish_reason
@@ -148,12 +174,15 @@ def combine_completion_chunks(chunks: list[CompletionChunk]) -> Completion:
         messages_per_choice[index] = AssistantMessage(
             name=name,
             content=contents_per_choice[index] or "<empty>",
+            reasoning_content=(reasoning_contents_per_choice[index] or None),
+            thinking_blocks=(thinking_blocks_per_choice[index] or None),
+            annotations=(annotations_per_choice[index] or None),
             refusal=(refusals_per_choice[index] or None),
             tool_calls=(tool_calls or None),
         )
 
         if logp_contents_per_choice[index] or logp_refusals_per_choice[index]:
-            logprobs_per_choice[index] = CompletionChoiceLogprobs(
+            logprobs_per_choice[index] = OpenAIChoiceLogprobs(
                 content=logp_contents_per_choice[index],
                 refusal=logp_refusals_per_choice[index],
             )
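The public entry point is unchanged — `combine_completion_chunks(chunks: list[CompletionChunk]) -> Completion` — but the merged `AssistantMessage` now also accumulates `reasoning_content`, `thinking_blocks`, and `annotations` across deltas. A minimal usage sketch, assuming chunks in one stream must share an `id` (the `CombineCompletionChunksError` import suggests such validation) and that `FinishReason` accepts the OpenAI-style `"stop"` literal:

from grasp_agents.typing.completion_chunk import (
    CompletionChunk,
    CompletionChunkChoice,
    CompletionChunkChoiceDelta,
    combine_completion_chunks,
)


def make_chunk(piece: str, finish: str | None) -> CompletionChunk:
    # Hypothetical helper: one single-choice delta of a streamed reply.
    return CompletionChunk(
        id="chatcmpl-demo",  # shared across the stream (assumed requirement)
        model="gpt-4o",
        choices=[
            CompletionChunkChoice(
                delta=CompletionChunkChoiceDelta(
                    content=piece, role=None, tool_calls=None
                ),
                finish_reason=finish,  # assumed to validate as FinishReason
                index=0,
            )
        ],
    )


completion = combine_completion_chunks([make_chunk("Hel", None), make_chunk("lo!", "stop")])
# Assumed CompletionChoice shape: each choice carries the merged AssistantMessage.
print(completion.choices[0].message.content)  # -> "Hello!"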
grasp_agents/typing/events.py

@@ -1,8 +1,9 @@
 import time
 from enum import StrEnum
 from typing import Any, Generic, Literal, TypeVar
+from uuid import uuid4
 
-from pydantic import BaseModel, Field
+from pydantic import BaseModel, ConfigDict, Field
 
 from ..packet import Packet
 from .completion import Completion
@@ -15,7 +16,9 @@ class EventSourceType(StrEnum):
     AGENT = "agent"
     USER = "user"
     TOOL = "tool"
-    PROCESSOR = "processor"
+    PROC = "processor"
+    WORKFLOW = "workflow"
+    RUN = "run"
 
 
 class EventType(StrEnum):
@@ -24,10 +27,22 @@ class EventType(StrEnum):
     TOOL_MSG = "tool_message"
     TOOL_CALL = "tool_call"
     GEN_MSG = "gen_message"
+
     COMP = "completion"
     COMP_CHUNK = "completion_chunk"
-    PACKET = "packet"
-    PROC_OUT = "processor_output"
+    LLM_ERR = "llm_error"
+
+    PROC_START = "processor_start"
+    PACKET_OUT = "packet_output"
+    PAYLOAD_OUT = "payload_output"
+    PROC_FINISH = "processor_finish"
+    PROC_ERR = "processor_error"
+
+    WORKFLOW_RES = "workflow_result"
+    RUN_RES = "run_result"
+
+    # COMP_THINK_CHUNK = "completion_thinking_chunk"
+    # COMP_RESP_CHUNK = "completion_response_chunk"
 
 
 _T = TypeVar("_T")
@@ -36,8 +51,10 @@ _T = TypeVar("_T")
 class Event(BaseModel, Generic[_T], frozen=True):
     type: EventType
     source: EventSourceType
+    id: str = Field(default_factory=lambda: str(uuid4()))
     created: int = Field(default_factory=lambda: int(time.time()))
-    name: str | None = None
+    proc_name: str | None = None
+    call_id: str | None = None
     data: _T
 
 
@@ -51,36 +68,95 @@ class CompletionChunkEvent(Event[CompletionChunk], frozen=True):
     source: Literal[EventSourceType.LLM] = EventSourceType.LLM
 
 
-class GenMessageEvent(Event[AssistantMessage], frozen=True):
-    type: Literal[EventType.GEN_MSG] = EventType.GEN_MSG
+class LLMStreamingErrorData(BaseModel):
+    error: Exception
+    model_name: str | None = None
+    model_id: str | None = None
+
+    model_config = ConfigDict(arbitrary_types_allowed=True)
+
+
+class LLMStreamingErrorEvent(Event[LLMStreamingErrorData], frozen=True):
+    type: Literal[EventType.LLM_ERR] = EventType.LLM_ERR
     source: Literal[EventSourceType.LLM] = EventSourceType.LLM
 
 
-class ToolCallEvent(Event[ToolCall], frozen=True):
-    type: Literal[EventType.TOOL_CALL] = EventType.TOOL_CALL
-    source: Literal[EventSourceType.AGENT] = EventSourceType.AGENT
+# class CompletionThinkingChunkEvent(Event[CompletionChunk], frozen=True):
+#     type: Literal[EventType.COMP_THINK_CHUNK] = EventType.COMP_THINK_CHUNK
+#     source: Literal[EventSourceType.LLM] = EventSourceType.LLM
+
+
+# class CompletionResponseChunkEvent(Event[CompletionChunk], frozen=True):
+#     type: Literal[EventType.COMP_RESP_CHUNK] = EventType.COMP_RESP_CHUNK
+#     source: Literal[EventSourceType.LLM] = EventSourceType.LLM
+
+
+class MessageEvent(Event[_T], Generic[_T], frozen=True):
+    pass
+
+
+class GenMessageEvent(MessageEvent[AssistantMessage], frozen=True):
+    type: Literal[EventType.GEN_MSG] = EventType.GEN_MSG
+    source: Literal[EventSourceType.LLM] = EventSourceType.LLM
 
 
-class ToolMessageEvent(Event[ToolMessage], frozen=True):
+class ToolMessageEvent(MessageEvent[ToolMessage], frozen=True):
     type: Literal[EventType.TOOL_MSG] = EventType.TOOL_MSG
     source: Literal[EventSourceType.TOOL] = EventSourceType.TOOL
 
 
-class UserMessageEvent(Event[UserMessage], frozen=True):
+class UserMessageEvent(MessageEvent[UserMessage], frozen=True):
     type: Literal[EventType.USR_MSG] = EventType.USR_MSG
     source: Literal[EventSourceType.USER] = EventSourceType.USER
 
 
-class SystemMessageEvent(Event[SystemMessage], frozen=True):
+class SystemMessageEvent(MessageEvent[SystemMessage], frozen=True):
     type: Literal[EventType.SYS_MSG] = EventType.SYS_MSG
     source: Literal[EventSourceType.AGENT] = EventSourceType.AGENT
 
 
-class PacketEvent(Event[Packet[Any]], frozen=True):
-    type: Literal[EventType.PACKET] = EventType.PACKET
-    source: Literal[EventSourceType.PROCESSOR] = EventSourceType.PROCESSOR
+class ToolCallEvent(Event[ToolCall], frozen=True):
+    type: Literal[EventType.TOOL_CALL] = EventType.TOOL_CALL
+    source: Literal[EventSourceType.AGENT] = EventSourceType.AGENT
+
+
+class ProcStartEvent(Event[None], frozen=True):
+    type: Literal[EventType.PROC_START] = EventType.PROC_START
+    source: Literal[EventSourceType.PROC] = EventSourceType.PROC
+
+
+class ProcFinishEvent(Event[None], frozen=True):
+    type: Literal[EventType.PROC_FINISH] = EventType.PROC_FINISH
+    source: Literal[EventSourceType.PROC] = EventSourceType.PROC
+
+
+class ProcPayloadOutputEvent(Event[Any], frozen=True):
+    type: Literal[EventType.PAYLOAD_OUT] = EventType.PAYLOAD_OUT
+    source: Literal[EventSourceType.PROC] = EventSourceType.PROC
+
+
+class ProcPacketOutputEvent(Event[Packet[Any]], frozen=True):
+    type: Literal[EventType.PACKET_OUT] = EventType.PACKET_OUT
+    source: Literal[EventSourceType.PROC] = EventSourceType.PROC
+
+
+class WorkflowResultEvent(Event[Packet[Any]], frozen=True):
+    type: Literal[EventType.WORKFLOW_RES] = EventType.WORKFLOW_RES
+    source: Literal[EventSourceType.WORKFLOW] = EventSourceType.WORKFLOW
+
+
+class RunResultEvent(Event[Packet[Any]], frozen=True):
+    type: Literal[EventType.RUN_RES] = EventType.RUN_RES
+    source: Literal[EventSourceType.RUN] = EventSourceType.RUN
+
+
+class ProcStreamingErrorData(BaseModel):
+    error: Exception
+    call_id: str | None = None
+
+    model_config = ConfigDict(arbitrary_types_allowed=True)
 
 
-class ProcOutputEvent(Event[Any], frozen=True):
-    type: Literal[EventType.PROC_OUT] = EventType.PROC_OUT
-    source: Literal[EventSourceType.PROCESSOR] = EventSourceType.PROCESSOR
+class ProcStreamingErrorEvent(Event[ProcStreamingErrorData], frozen=True):
+    type: Literal[EventType.PROC_ERR] = EventType.PROC_ERR
+    source: Literal[EventSourceType.PROC] = EventSourceType.PROC
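With sources split into processor, workflow, and run scopes and explicit error events added, consumers can dispatch on concrete event classes instead of matching raw `type` strings. A minimal consumer sketch, assuming the stream is an `AsyncIterator[Event[Any]]` (the producer side is not shown in this hunk):

from collections.abc import AsyncIterator
from typing import Any

from grasp_agents.typing.events import (
    CompletionChunkEvent,
    Event,
    ProcStreamingErrorEvent,
    RunResultEvent,
)


async def consume(stream: AsyncIterator[Event[Any]]) -> None:
    async for event in stream:
        if isinstance(event, CompletionChunkEvent):
            # event.data is a CompletionChunk; proc_name and call_id identify the producer
            print(event.proc_name, event.call_id, event.data)
        elif isinstance(event, ProcStreamingErrorEvent):
            # ProcStreamingErrorData wraps the original exception
            raise event.data.error
        elif isinstance(event, RunResultEvent):
            # event.data is the final Packet[Any] of the run
            return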
grasp_agents/typing/message.py

@@ -1,11 +1,13 @@
 import json
 from collections.abc import Hashable, Mapping, Sequence
 from enum import StrEnum
-from typing import Annotated, Any, Literal, TypeAlias
+from typing import Annotated, Any, Literal, Required, TypeAlias
 from uuid import uuid4
 
+from litellm.types.llms.openai import ChatCompletionAnnotation as LiteLLMAnnotation
 from pydantic import BaseModel, Field
 from pydantic.json import pydantic_encoder
+from typing_extensions import TypedDict
 
 from .content import Content, ImageData
 from .tool import ToolCall
@@ -14,6 +16,7 @@ from .tool import ToolCall
 class Role(StrEnum):
     USER = "user"
     SYSTEM = "system"
+    DEVELOPER = "developer"
     ASSISTANT = "assistant"
     TOOL = "tool"
@@ -23,11 +26,32 @@ class MessageBase(BaseModel):
     name: str | None = None
 
 
+class ChatCompletionCachedContent(TypedDict):
+    type: Literal["ephemeral"]
+
+
+class ThinkingBlock(TypedDict, total=False):
+    type: Required[Literal["thinking"]]
+    thinking: str
+    signature: str | None
+    cache_control: dict[str, Any] | ChatCompletionCachedContent | None
+
+
+class RedactedThinkingBlock(TypedDict, total=False):
+    type: Required[Literal["redacted_thinking"]]
+    data: str
+    cache_control: dict[str, Any] | ChatCompletionCachedContent | None
+
+
 class AssistantMessage(MessageBase):
     role: Literal[Role.ASSISTANT] = Role.ASSISTANT
     content: str | None
     tool_calls: Sequence[ToolCall] | None = None
     refusal: str | None = None
+    reasoning_content: str | None = None
+    thinking_blocks: Sequence[ThinkingBlock | RedactedThinkingBlock] | None = None
+    annotations: Sequence[LiteLLMAnnotation] | None = None
+    provider_specific_fields: dict[str, Any] | None = None
 
 
 class UserMessage(MessageBase):
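`ThinkingBlock` and `RedactedThinkingBlock` are TypedDicts (mirroring LiteLLM's Anthropic-style thinking content), so they are constructed like plain dicts and validated by pydantic inside `AssistantMessage`. A sketch with illustrative values:

from grasp_agents.typing.message import AssistantMessage, ThinkingBlock

msg = AssistantMessage(
    content="The capital of France is Paris.",
    reasoning_content="User asks for a capital; France -> Paris.",
    thinking_blocks=[
        # total=False: only the Required "type" key is mandatory
        ThinkingBlock(type="thinking", thinking="France -> Paris"),
    ],
)
print(msg.role)  # Role.ASSISTANT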
grasp_agents/typing/tool.py

@@ -48,6 +48,8 @@ class BaseTool(
     name: str
     description: str
 
+    strict: bool | None = None
+
     _in_type: type[_InT] = PrivateAttr()
     _out_type: type[_OutT_co] = PrivateAttr()
grasp_agents/usage_tracker.py

@@ -20,7 +20,6 @@ CostsDict: TypeAlias = dict[str, ModelCostsDict]
 
 
 class UsageTracker(BaseModel):
-    # TODO: specify different costs per provider:model, not just per model
     costs_dict_path: str | Path = COSTS_DICT_PATH
     costs_dict: CostsDict | None = None
     usages: dict[str, Usage] = Field(default_factory=dict)
@@ -29,34 +28,6 @@ class UsageTracker(BaseModel):
         super().__init__(**kwargs)
         self.costs_dict = self.load_costs_dict()
 
-    def load_costs_dict(self) -> CostsDict | None:
-        try:
-            with Path(self.costs_dict_path).open() as f:
-                return yaml.safe_load(f)["costs"]
-        except Exception:
-            logger.info(f"Failed to load cost dictionary from {self.costs_dict_path}")
-            return None
-
-    def _add_cost_to_usage(
-        self, usage: Usage, model_costs_dict: ModelCostsDict
-    ) -> None:
-        in_rate = model_costs_dict["input"]
-        out_rate = model_costs_dict["output"]
-        cached_discount = model_costs_dict.get("cached_discount")
-        input_cost = in_rate * usage.input_tokens
-        output_cost = out_rate * usage.output_tokens
-        reasoning_cost = (
-            out_rate * usage.reasoning_tokens
-            if usage.reasoning_tokens is not None
-            else 0.0
-        )
-        cached_cost: float = (
-            cached_discount * in_rate * usage.cached_tokens
-            if (usage.cached_tokens is not None) and (cached_discount is not None)
-            else 0.0
-        )
-        usage.cost = (input_cost + output_cost + reasoning_cost + cached_cost) / 1e6
-
     def update(
         self,
         agent_name: str,
@@ -64,13 +35,13 @@ class UsageTracker(BaseModel):
         model_name: str | None = None,
     ) -> None:
         if model_name is not None and self.costs_dict is not None:
-            model_costs_dict = self.costs_dict.get(model_name.split(":", 1)[-1])
+            model_costs_dict = self.costs_dict.get(model_name.split("/", 1)[-1])
         else:
             model_costs_dict = None
 
         for completion in completions:
             if completion.usage is not None:
-                if model_costs_dict is not None:
+                if completion.usage.cost is None and model_costs_dict is not None:
                     self._add_cost_to_usage(
                         usage=completion.usage, model_costs_dict=model_costs_dict
                     )
@@ -100,9 +71,32 @@ class UsageTracker(BaseModel):
             logger.debug(colored(token_usage_str, "light_grey"))
 
         if usage.cost is not None:
-            logger.debug(
-                colored(
-                    f"Total cost: ${usage.cost:.4f}",
-                    "light_grey",
-                )
-            )
+            logger.debug(colored(f"Total cost: ${usage.cost:.4f}", "light_grey"))
+
+    def load_costs_dict(self) -> CostsDict | None:
+        try:
+            with Path(self.costs_dict_path).open() as f:
+                return yaml.safe_load(f)["costs"]
+        except Exception:
+            logger.info(f"Failed to load cost dictionary from {self.costs_dict_path}")
+            return None
+
+    def _add_cost_to_usage(
+        self, usage: Usage, model_costs_dict: ModelCostsDict
+    ) -> None:
+        in_rate = model_costs_dict["input"]
+        out_rate = model_costs_dict["output"]
+        cached_discount = model_costs_dict.get("cached_discount")
+        input_cost = in_rate * usage.input_tokens
+        output_cost = out_rate * usage.output_tokens
+        reasoning_cost = (
+            out_rate * usage.reasoning_tokens
+            if usage.reasoning_tokens is not None
+            else 0.0
+        )
+        cached_cost: float = (
+            cached_discount * in_rate * usage.cached_tokens
+            if (usage.cached_tokens is not None) and (cached_discount is not None)
+            else 0.0
+        )
+        usage.cost = (input_cost + output_cost + reasoning_cost + cached_cost) / 1e6
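`_add_cost_to_usage` reads the configured rates as prices per million tokens, bills reasoning tokens at the output rate, and bills cached tokens at a discounted input rate; `update` now skips recomputation when the provider already reported a cost, and the cost-dict lookup strips a LiteLLM-style `provider/model` prefix rather than the old `provider:model` form. A worked example with hypothetical rates:

# Hypothetical per-1M-token rates, mirroring _add_cost_to_usage above.
in_rate, out_rate, cached_discount = 2.5, 10.0, 0.5

input_tokens, output_tokens, reasoning_tokens, cached_tokens = 1_000, 200, 0, 400

cost = (
    in_rate * input_tokens                       # 2500.0
    + out_rate * output_tokens                   # 2000.0
    + out_rate * reasoning_tokens                # 0.0
    + cached_discount * in_rate * cached_tokens  # 500.0
) / 1e6
print(f"${cost:.4f}")  # -> $0.0050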