grasp_agents 0.4.7__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44)
  1. grasp_agents/cloud_llm.py +191 -224
  2. grasp_agents/comm_processor.py +101 -100
  3. grasp_agents/errors.py +69 -9
  4. grasp_agents/litellm/__init__.py +106 -0
  5. grasp_agents/litellm/completion_chunk_converters.py +68 -0
  6. grasp_agents/litellm/completion_converters.py +72 -0
  7. grasp_agents/litellm/converters.py +138 -0
  8. grasp_agents/litellm/lite_llm.py +210 -0
  9. grasp_agents/litellm/message_converters.py +66 -0
  10. grasp_agents/llm.py +84 -49
  11. grasp_agents/llm_agent.py +136 -120
  12. grasp_agents/llm_agent_memory.py +3 -3
  13. grasp_agents/llm_policy_executor.py +167 -174
  14. grasp_agents/memory.py +4 -0
  15. grasp_agents/openai/__init__.py +24 -9
  16. grasp_agents/openai/completion_chunk_converters.py +6 -6
  17. grasp_agents/openai/completion_converters.py +12 -14
  18. grasp_agents/openai/content_converters.py +1 -3
  19. grasp_agents/openai/converters.py +6 -8
  20. grasp_agents/openai/message_converters.py +21 -3
  21. grasp_agents/openai/openai_llm.py +155 -103
  22. grasp_agents/openai/tool_converters.py +4 -6
  23. grasp_agents/packet.py +5 -2
  24. grasp_agents/packet_pool.py +14 -13
  25. grasp_agents/printer.py +234 -72
  26. grasp_agents/processor.py +228 -88
  27. grasp_agents/prompt_builder.py +2 -2
  28. grasp_agents/run_context.py +11 -20
  29. grasp_agents/runner.py +42 -0
  30. grasp_agents/typing/completion.py +16 -9
  31. grasp_agents/typing/completion_chunk.py +51 -22
  32. grasp_agents/typing/events.py +95 -19
  33. grasp_agents/typing/message.py +25 -1
  34. grasp_agents/typing/tool.py +2 -0
  35. grasp_agents/usage_tracker.py +31 -37
  36. grasp_agents/utils.py +95 -84
  37. grasp_agents/workflow/looped_workflow.py +60 -11
  38. grasp_agents/workflow/sequential_workflow.py +43 -11
  39. grasp_agents/workflow/workflow_processor.py +25 -24
  40. {grasp_agents-0.4.7.dist-info → grasp_agents-0.5.0.dist-info}/METADATA +7 -6
  41. grasp_agents-0.5.0.dist-info/RECORD +57 -0
  42. grasp_agents-0.4.7.dist-info/RECORD +0 -50
  43. {grasp_agents-0.4.7.dist-info → grasp_agents-0.5.0.dist-info}/WHEEL +0 -0
  44. {grasp_agents-0.4.7.dist-info → grasp_agents-0.5.0.dist-info}/licenses/LICENSE.md +0 -0
grasp_agents/comm_processor.py CHANGED
@@ -2,24 +2,18 @@ import logging
 from collections.abc import AsyncIterator, Sequence
 from typing import Any, ClassVar, Generic, Protocol, TypeVar, cast
 
-from pydantic import BaseModel
-from pydantic.json_schema import SkipJsonSchema
-
+from .errors import PacketRoutingError
 from .memory import MemT
 from .packet import Packet
 from .packet_pool import PacketPool
 from .processor import Processor
 from .run_context import CtxT, RunContext
-from .typing.events import Event, PacketEvent
+from .typing.events import Event, ProcPacketOutputEvent, RunResultEvent
 from .typing.io import InT, OutT_co, ProcName
 
 logger = logging.getLogger(__name__)
 
 
-class DynCommPayload(BaseModel):
-    selected_recipients: SkipJsonSchema[Sequence[ProcName]]
-
-
 _OutT_contra = TypeVar("_OutT_contra", contravariant=True)
 
 
@@ -27,10 +21,16 @@ class ExitCommunicationHandler(Protocol[_OutT_contra, CtxT]):
     def __call__(
         self,
         out_packet: Packet[_OutT_contra],
-        ctx: RunContext[CtxT] | None,
+        ctx: RunContext[CtxT],
     ) -> bool: ...
 
 
+class SetRecipientsHandler(Protocol[_OutT_contra, CtxT]):
+    def __call__(
+        self, out_packet: Packet[_OutT_contra], ctx: RunContext[CtxT]
+    ) -> None: ...
+
+
 class CommProcessor(
     Processor[InT, OutT_co, MemT, CtxT],
     Generic[InT, OutT_co, MemT, CtxT],
@@ -46,49 +46,45 @@ class CommProcessor(
         *,
         recipients: Sequence[ProcName] | None = None,
         packet_pool: PacketPool[CtxT] | None = None,
-        num_par_run_retries: int = 0,
+        max_retries: int = 0,
     ) -> None:
-        super().__init__(name=name, num_par_run_retries=num_par_run_retries)
+        super().__init__(name=name, max_retries=max_retries)
 
         self.recipients = recipients or []
-
         self._packet_pool = packet_pool
         self._is_listening = False
+
         self._exit_communication_impl: (
             ExitCommunicationHandler[OutT_co, CtxT] | None
         ) = None
+        self._set_recipients_impl: SetRecipientsHandler[OutT_co, CtxT] | None = None
 
     @property
     def packet_pool(self) -> PacketPool[CtxT] | None:
         return self._packet_pool
 
-    def _validate_routing(self, payloads: Sequence[OutT_co]) -> Sequence[ProcName]:
-        if all(isinstance(p, DynCommPayload) for p in payloads):
-            payloads_ = cast("Sequence[DynCommPayload]", payloads)
-            selected_recipients_per_payload = [
-                set(p.selected_recipients or []) for p in payloads_
-            ]
-            assert all(
-                x == selected_recipients_per_payload[0]
-                for x in selected_recipients_per_payload
-            ), "All payloads must have the same recipient IDs for dynamic routing"
-
-            assert payloads_[0].selected_recipients is not None
-            selected_recipients = payloads_[0].selected_recipients
+    @property
+    def is_listening(self) -> bool:
+        return self._is_listening
 
-            assert all(rid in self.recipients for rid in selected_recipients), (
-                "Dynamic routing is enabled, but recipient IDs are not in "
-                "the allowed agent's recipient IDs"
-            )
+    def _set_recipients(
+        self, out_packet: Packet[OutT_co], ctx: RunContext[CtxT]
+    ) -> None:
+        if self._set_recipients_impl:
+            self._set_recipients_impl(out_packet=out_packet, ctx=ctx)
+            return
 
-            return selected_recipients
+        out_packet.recipients = self.recipients
 
-        if all((not isinstance(p, DynCommPayload)) for p in payloads):
-            return self.recipients
+    def _validate_routing(self, recipients: Sequence[ProcName]) -> Sequence[ProcName]:
+        for r in recipients:
+            if r not in self.recipients:
+                raise PacketRoutingError(
+                    selected_recipient=r,
+                    allowed_recipients=cast("list[str]", self.recipients),
+                )
 
-        raise ValueError(
-            "All payloads must be either DCommAgentPayload or not DCommAgentPayload"
-        )
+        return self.recipients
 
     async def run(
         self,
@@ -97,117 +93,122 @@ class CommProcessor(
         in_packet: Packet[InT] | None = None,
         in_args: InT | Sequence[InT] | None = None,
         forgetful: bool = False,
-        run_id: str | None = None,
+        call_id: str | None = None,
         ctx: RunContext[CtxT] | None = None,
     ) -> Packet[OutT_co]:
+        call_id = self._generate_call_id(call_id)
+
         out_packet = await super().run(
             chat_inputs=chat_inputs,
             in_packet=in_packet,
             in_args=in_args,
             forgetful=forgetful,
-            run_id=run_id,
+            call_id=call_id,
             ctx=ctx,
         )
-        recipients = self._validate_routing(out_packet.payloads)
-        routed_out_packet = Packet(
-            payloads=out_packet.payloads, sender=self.name, recipients=recipients
-        )
-        if self._packet_pool is not None and in_packet is None and in_args is None:
-            # If no input packet or args, we assume this is the first run.
-            await self._packet_pool.post(routed_out_packet)
 
-        return routed_out_packet
+        if self._packet_pool is not None:
+            if ctx is None:
+                raise ValueError("RunContext must be provided when using PacketPool")
+            if self._exit_communication(out_packet=out_packet, ctx=ctx):
+                ctx.result = out_packet
+                await self._packet_pool.stop_all()
+                return out_packet
+
+            self._set_recipients(out_packet=out_packet, ctx=ctx)
+            out_packet.recipients = self._validate_routing(out_packet.recipients)
+
+            await self._packet_pool.post(out_packet)
+
+        return out_packet
 
     async def run_stream(
         self,
         chat_inputs: Any | None = None,
         *,
         in_packet: Packet[InT] | None = None,
-        in_args: InT | None = None,
+        in_args: InT | Sequence[InT] | None = None,
         forgetful: bool = False,
-        run_id: str | None = None,
+        call_id: str | None = None,
         ctx: RunContext[CtxT] | None = None,
     ) -> AsyncIterator[Event[Any]]:
+        call_id = self._generate_call_id(call_id)
+
         out_packet: Packet[OutT_co] | None = None
         async for event in super().run_stream(
            chat_inputs=chat_inputs,
            in_packet=in_packet,
            in_args=in_args,
            forgetful=forgetful,
-            run_id=run_id,
+            call_id=call_id,
            ctx=ctx,
        ):
-            if isinstance(event, PacketEvent):
+            if isinstance(event, ProcPacketOutputEvent):
                 out_packet = event.data
             else:
                 yield event
 
         if out_packet is None:
-            raise RuntimeError("No output packet generated during stream run")
+            return
+
+        if self._packet_pool is not None:
+            if ctx is None:
+                raise ValueError("RunContext must be provided when using PacketPool")
+            if self._exit_communication(out_packet=out_packet, ctx=ctx):
+                ctx.result = out_packet
+                yield RunResultEvent(
+                    data=out_packet, proc_name=self.name, call_id=call_id
+                )
+                await self._packet_pool.stop_all()
+                return
+
+            self._set_recipients(out_packet=out_packet, ctx=ctx)
+            out_packet.recipients = self._validate_routing(out_packet.recipients)
+
+            await self._packet_pool.post(out_packet)
 
-        recipients = self._validate_routing(out_packet.payloads)
-        routed_out_packet = Packet(
-            payloads=out_packet.payloads, sender=self.name, recipients=recipients
+        yield ProcPacketOutputEvent(
+            data=out_packet, proc_name=self.name, call_id=call_id
         )
-        if self._packet_pool is not None and in_packet is None and in_args is None:
-            # If no input packet or args, we assume this is the first run.
-            await self._packet_pool.post(routed_out_packet)
 
-        yield PacketEvent(data=routed_out_packet, name=self.name)
+    def start_listening(self, ctx: RunContext[CtxT], **run_kwargs: Any) -> None:
+        if self._packet_pool is None:
+            raise RuntimeError("Packet pool must be initialized before listening")
 
-    def exit_communication(
-        self, func: ExitCommunicationHandler[OutT_co, CtxT]
-    ) -> ExitCommunicationHandler[OutT_co, CtxT]:
-        self._exit_communication_impl = func
+        if self._is_listening:
+            return
+        self._is_listening = True
 
-        return func
+        self._packet_pool.register_packet_handler(
+            processor_name=self.name,
+            handler=self.run_stream if ctx.is_streaming else self.run,  # type: ignore[call-arg]
+            ctx=ctx,
+            **run_kwargs,
+        )
 
     def _exit_communication(
-        self, out_packet: Packet[OutT_co], ctx: RunContext[CtxT] | None
+        self, out_packet: Packet[OutT_co], ctx: RunContext[CtxT]
     ) -> bool:
         if self._exit_communication_impl:
             return self._exit_communication_impl(out_packet=out_packet, ctx=ctx)
 
         return False
 
-    async def _packet_handler(
-        self,
-        packet: Packet[InT],
-        ctx: RunContext[CtxT] | None = None,
-        **run_kwargs: Any,
-    ) -> None:
-        assert self._packet_pool is not None, "Packet pool must be initialized"
-
-        out_packet = await self.run(ctx=ctx, in_packet=packet, **run_kwargs)
-
-        if self._exit_communication(out_packet=out_packet, ctx=ctx):
-            await self._packet_pool.stop_all()
-            return
-
-        await self._packet_pool.post(out_packet)
-
-    @property
-    def is_listening(self) -> bool:
-        return self._is_listening
-
-    async def start_listening(
-        self, ctx: RunContext[CtxT] | None = None, **run_kwargs: Any
-    ) -> None:
-        assert self._packet_pool is not None, "Packet pool must be initialized"
+    def exit_communication(
+        self, func: ExitCommunicationHandler[OutT_co, CtxT]
+    ) -> ExitCommunicationHandler[OutT_co, CtxT]:
+        self._exit_communication_impl = func
 
-        if self._is_listening:
-            return
+        return func
 
-        self._is_listening = True
-        self._packet_pool.register_packet_handler(
-            processor_name=self.name,
-            handler=self._packet_handler,
-            ctx=ctx,
-            **run_kwargs,
-        )
+    def set_recipients(
+        self, func: SetRecipientsHandler[OutT_co, CtxT]
+    ) -> SetRecipientsHandler[OutT_co, CtxT]:
+        self._select_recipients_impl = func
 
-    async def stop_listening(self) -> None:
-        assert self._packet_pool is not None, "Packet pool must be initialized"
+        return func
 
-        self._is_listening = False
-        await self._packet_pool.unregister_packet_handler(self.name)
+    # async def stop_listening(self) -> None:
+    #     assert self._packet_pool is not None
+    #     self._is_listening = False
+    #     await self._packet_pool.unregister_packet_handler(self.name)
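In 0.5.0, routing and exit control move from DynCommPayload fields to handlers registered on the processor itself. A minimal sketch of the new API, assuming an already-constructed CommProcessor instance `proc` with a packet pool and non-empty `recipients`, a RunContext `run_ctx`, and an illustrative output payload type `MyOut` (none of these names come from the diff):

    from typing import Any

    from pydantic import BaseModel

    from grasp_agents.packet import Packet
    from grasp_agents.run_context import RunContext


    class MyOut(BaseModel):  # illustrative output payload type
        is_final: bool = False


    @proc.exit_communication
    def should_exit(out_packet: Packet[MyOut], ctx: RunContext[Any]) -> bool:
        # Returning True stores the packet on ctx.result and stops the packet pool.
        return any(p.is_final for p in out_packet.payloads)


    @proc.set_recipients
    def route(out_packet: Packet[MyOut], ctx: RunContext[Any]) -> None:
        # Recipients are checked against proc.recipients; unknown names raise
        # PacketRoutingError.
        out_packet.recipients = ["reviewer"]


    # Registers run_stream or run with the packet pool, depending on ctx.is_streaming.
    proc.start_listening(ctx=run_ctx)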
grasp_agents/errors.py CHANGED
@@ -1,29 +1,46 @@
-class InputValidationError(Exception):
+# from openai import APIResponseValidationError
+class CompletionError(Exception):
     pass
 
 
-class StringParsingError(Exception):
+class CombineCompletionChunksError(Exception):
     pass
 
 
-class CompletionError(Exception):
+class ProcInputValidationError(Exception):
     pass
 
 
-class CombineCompletionChunksError(Exception):
+class ProcOutputValidationError(Exception):
     pass
 
 
-class ToolValidationError(Exception):
-    pass
+class AgentFinalAnswerError(Exception):
+    def __init__(self, message: str | None = None) -> None:
+        super().__init__(
+            message or "Final answer tool call did not return a final answer message."
+        )
+        self.message = message
 
 
-class OutputValidationError(Exception):
+class WorkflowConstructionError(Exception):
     pass
 
 
-class WorkflowConstructionError(Exception):
-    pass
+class PacketRoutingError(Exception):
+    def __init__(
+        self,
+        selected_recipient: str,
+        allowed_recipients: list[str],
+        message: str | None = None,
+    ) -> None:
+        default_message = (
+            f"Selected recipient '{selected_recipient}' is not in the allowed "
+            f"recipients: {allowed_recipients}"
+        )
+        super().__init__(message or default_message)
+        self.selected_recipient = selected_recipient
+        self.allowed_recipients = allowed_recipients
 
 
 class SystemPromptBuilderError(Exception):
@@ -32,3 +49,46 @@ class SystemPromptBuilderError(Exception):
 
 
 class InputPromptBuilderError(Exception):
     pass
+
+
+class PyJSONStringParsingError(Exception):
+    def __init__(self, s: str, message: str | None = None) -> None:
+        super().__init__(
+            message
+            or "Both ast.literal_eval and json.loads failed to parse the following "
+            f"JSON/Python string:\n{s}"
+        )
+        self.s = s
+
+
+class JSONSchemaValidationError(Exception):
+    def __init__(self, s: str, schema: object, message: str | None = None) -> None:
+        super().__init__(
+            message
+            or f"JSON schema validation failed for:\n{s}\nExpected type: {schema}"
+        )
+        self.s = s
+        self.schema = schema
+
+
+class LLMToolCallValidationError(Exception):
+    def __init__(
+        self, tool_name: str, tool_args: str, message: str | None = None
+    ) -> None:
+        super().__init__(
+            message
+            or f"Failed to validate tool call '{tool_name}' with arguments:"
+            f"\n{tool_args}."
+        )
+        self.tool_name = tool_name
+        self.tool_args = tool_args
+
+
+class LLMResponseValidationError(JSONSchemaValidationError):
+    def __init__(self, s: str, schema: object, message: str | None = None) -> None:
+        super().__init__(
+            s,
+            schema,
+            message
+            or f"Failed to validate LLM response:\n{s}\nExpected type: {schema}",
+        )
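The renamed and added exceptions carry structured fields rather than bare messages, so callers can inspect them programmatically. A minimal sketch, assuming an illustrative async wrapper around a processor run (the `proc`, `packet`, and `run_ctx` names are not from the diff):

    import logging

    from grasp_agents.errors import LLMToolCallValidationError, PacketRoutingError

    logger = logging.getLogger(__name__)


    async def run_once(proc, packet, run_ctx):  # illustrative wrapper
        try:
            return await proc.run(in_packet=packet, ctx=run_ctx)
        except PacketRoutingError as err:
            # Structured fields added in 0.5.0
            logger.warning(
                "Recipient %r not allowed; allowed: %s",
                err.selected_recipient,
                err.allowed_recipients,
            )
        except LLMToolCallValidationError as err:
            logger.warning(
                "Tool call %r failed validation: %s", err.tool_name, err.tool_args
            )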
grasp_agents/litellm/__init__.py ADDED
@@ -0,0 +1,106 @@
+# pyright: reportUnusedImport=false
+
+from litellm.types.utils import ChatCompletionMessageToolCall as LiteLLMToolCall
+from litellm.types.utils import Choices as LiteLLMChoice
+from litellm.types.utils import Function as LiteLLMFunction
+from litellm.types.utils import Message as LiteLLMCompletionMessage
+from litellm.types.utils import ModelResponse as LiteLLMCompletion
+from litellm.types.utils import ModelResponseStream as LiteLLMCompletionChunk
+from litellm.types.utils import StreamingChoices as LiteLLMChunkChoice
+from litellm.types.utils import Usage as LiteLLMUsage
+from openai._streaming import (
+    AsyncStream as OpenAIAsyncStream,  # type: ignore[import] # noqa: PLC2701
+)
+from openai.types import CompletionUsage as OpenAIUsage
+from openai.types.chat.chat_completion import ChatCompletion as OpenAICompletion
+from openai.types.chat.chat_completion import (
+    ChoiceLogprobs as OpenAIChoiceLogprobs,
+)
+from openai.types.chat.chat_completion_assistant_message_param import (
+    ChatCompletionAssistantMessageParam as OpenAIAssistantMessageParam,
+)
+from openai.types.chat.chat_completion_chunk import (
+    ChatCompletionChunk as OpenAICompletionChunk,
+)
+from openai.types.chat.chat_completion_chunk import (
+    Choice as OpenAIChunkChoice,
+)
+from openai.types.chat.chat_completion_chunk import (
+    ChoiceDelta as OpenAIChunkChoiceDelta,
+)
+from openai.types.chat.chat_completion_chunk import (
+    ChoiceDeltaToolCall as OpenAIChunkChoiceDeltaToolCall,
+)
+from openai.types.chat.chat_completion_content_part_image_param import (
+    ChatCompletionContentPartImageParam as OpenAIContentPartImageParam,
+)
+from openai.types.chat.chat_completion_content_part_image_param import (
+    ImageURL as OpenAIImageURL,
+)
+from openai.types.chat.chat_completion_content_part_param import (
+    ChatCompletionContentPartParam as OpenAIContentPartParam,
+)
+from openai.types.chat.chat_completion_content_part_text_param import (
+    ChatCompletionContentPartTextParam as OpenAIContentPartTextParam,
+)
+from openai.types.chat.chat_completion_developer_message_param import (
+    ChatCompletionDeveloperMessageParam as OpenAIDeveloperMessageParam,
+)
+from openai.types.chat.chat_completion_function_message_param import (
+    ChatCompletionFunctionMessageParam as OpenAIFunctionMessageParam,
+)
+from openai.types.chat.chat_completion_message import (
+    ChatCompletionMessage as OpenAICompletionMessage,
+)
+from openai.types.chat.chat_completion_message_param import (
+    ChatCompletionMessageParam as OpenAIMessageParam,
+)
+from openai.types.chat.chat_completion_message_tool_call_param import (
+    ChatCompletionMessageToolCallParam as OpenAIToolCallParam,
+)
+from openai.types.chat.chat_completion_message_tool_call_param import (
+    Function as OpenAIToolCallFunction,
+)
+from openai.types.chat.chat_completion_named_tool_choice_param import (
+    ChatCompletionNamedToolChoiceParam as OpenAINamedToolChoiceParam,
+)
+from openai.types.chat.chat_completion_named_tool_choice_param import (
+    Function as OpenAINamedToolChoiceFunction,
+)
+from openai.types.chat.chat_completion_prediction_content_param import (
+    ChatCompletionPredictionContentParam as OpenAIPredictionContentParam,
+)
+from openai.types.chat.chat_completion_stream_options_param import (
+    ChatCompletionStreamOptionsParam as OpenAIStreamOptionsParam,
+)
+from openai.types.chat.chat_completion_system_message_param import (
+    ChatCompletionSystemMessageParam as OpenAISystemMessageParam,
+)
+from openai.types.chat.chat_completion_tool_choice_option_param import (
+    ChatCompletionToolChoiceOptionParam as OpenAIToolChoiceOptionParam,
+)
+from openai.types.chat.chat_completion_tool_message_param import (
+    ChatCompletionToolMessageParam as OpenAIToolMessageParam,
+)
+from openai.types.chat.chat_completion_tool_param import (
+    ChatCompletionToolParam as OpenAIToolParam,
+)
+from openai.types.chat.chat_completion_user_message_param import (
+    ChatCompletionUserMessageParam as OpenAIUserMessageParam,
+)
+from openai.types.chat.parsed_chat_completion import (
+    ParsedChatCompletion as OpenAIParsedCompletion,
+)
+from openai.types.chat.parsed_chat_completion import (
+    ParsedChatCompletionMessage as OpenAIParsedCompletionMessage,
+)
+from openai.types.chat.parsed_chat_completion import (
+    ParsedChoice as OpenAIParsedChoice,
+)
+from openai.types.shared_params.function_definition import (
+    FunctionDefinition as OpenAIFunctionDefinition,
+)
+
+from .lite_llm import LiteLLM, LiteLLMSettings
+
+__all__ = ["LiteLLM", "LiteLLMSettings"]
grasp_agents/litellm/completion_chunk_converters.py ADDED
@@ -0,0 +1,68 @@
+from ..openai.completion_converters import from_api_completion_usage
+from ..typing.completion_chunk import (
+    CompletionChunk,
+    CompletionChunkChoice,
+    CompletionChunkChoiceDelta,
+    CompletionChunkDeltaToolCall,
+)
+from . import LiteLLMChunkChoice, LiteLLMCompletionChunk
+
+
+def from_api_completion_chunk(
+    api_completion_chunk: LiteLLMCompletionChunk, name: str | None = None
+) -> CompletionChunk:
+    choices: list[CompletionChunkChoice] = []
+
+    for api_choice in api_completion_chunk.choices:
+        assert isinstance(api_choice, LiteLLMChunkChoice)
+
+        api_delta = api_choice.delta
+
+        delta = CompletionChunkChoiceDelta(
+            tool_calls=[
+                CompletionChunkDeltaToolCall(
+                    id=tool_call.id,
+                    index=tool_call.index,
+                    tool_name=tool_call.function.name,
+                    tool_arguments=tool_call.function.arguments,
+                )
+                for tool_call in (api_delta.tool_calls or [])
+                if tool_call.function
+            ],
+            content=api_delta.content,  # type: ignore[assignment, arg-type]
+            role=api_delta.role,  # type: ignore[assignment, arg-type]
+            thinking_blocks=getattr(api_delta, "thinking_blocks", None),
+            annotations=getattr(api_delta, "annotations", None),
+            reasoning_content=getattr(api_delta, "reasoning_content", None),
+            provider_specific_fields=api_delta.provider_specific_fields,
+            refusal=getattr(api_delta, "refusal", None),
+        )
+
+        choice = CompletionChunkChoice(
+            delta=delta,
+            index=api_choice.index,
+            finish_reason=api_choice.finish_reason,  # type: ignore[assignment, arg-type]
+            logprobs=getattr(api_choice, "logprobs", None),
+        )
+
+        choices.append(choice)
+
+    api_usage = getattr(api_completion_chunk, "usage", None)
+    usage = None
+    if api_usage is not None:
+        usage = from_api_completion_usage(api_usage)
+        hidden_params = getattr(api_completion_chunk, "_hidden_params", {})
+        usage.cost = getattr(hidden_params, "response_cost", None)
+
+    return CompletionChunk(
+        id=api_completion_chunk.id,
+        model=api_completion_chunk.model,
+        name=name,
+        created=api_completion_chunk.created,
+        system_fingerprint=api_completion_chunk.system_fingerprint,
+        choices=choices,
+        usage=usage,
+        provider_specific_fields=api_completion_chunk.provider_specific_fields,
+        hidden_params=api_completion_chunk._hidden_params,  # type: ignore[union-attr]
+        response_ms=getattr(api_completion_chunk, "_response_ms", None),
+    )
grasp_agents/litellm/completion_converters.py ADDED
@@ -0,0 +1,72 @@
+from typing import cast
+
+from ..typing.completion import Completion, CompletionChoice, Usage
+from . import LiteLLMChoice, LiteLLMCompletion, LiteLLMUsage
+from .message_converters import from_api_assistant_message
+
+
+def from_api_completion_usage(api_usage: LiteLLMUsage) -> Usage:
+    reasoning_tokens = None
+    cached_tokens = None
+
+    if api_usage.completion_tokens_details is not None:
+        reasoning_tokens = api_usage.completion_tokens_details.reasoning_tokens
+    if api_usage.prompt_tokens_details is not None:
+        cached_tokens = api_usage.prompt_tokens_details.cached_tokens
+
+    input_tokens = api_usage.prompt_tokens - (cached_tokens or 0)
+    output_tokens = api_usage.completion_tokens  # - (reasoning_tokens or 0)
+
+    return Usage(
+        input_tokens=input_tokens,
+        output_tokens=output_tokens,
+        reasoning_tokens=reasoning_tokens,
+        cached_tokens=cached_tokens,
+    )
+
+
+def from_api_completion(
+    api_completion: LiteLLMCompletion, name: str | None = None
+) -> Completion:
+    choices: list[CompletionChoice] = []
+    usage: Usage | None = None
+
+    for api_choice in api_completion.choices:
+        assert isinstance(api_choice, LiteLLMChoice)
+
+        message = from_api_assistant_message(api_choice.message, name=name)
+
+        choices.append(
+            CompletionChoice(
+                index=api_choice.index,
+                message=message,
+                finish_reason=api_choice.finish_reason,  # type: ignore[assignment, arg-type]
+                logprobs=getattr(api_choice, "logprobs", None),
+                provider_specific_fields=getattr(
+                    api_choice, "provider_specific_fields", None
+                ),
+            )
+        )
+
+    api_usage = getattr(api_completion, "usage", None)
+    usage = None
+    if api_usage:
+        usage = from_api_completion_usage(cast("LiteLLMUsage", api_usage))
+        hidden_params = getattr(api_completion, "_hidden_params", {})
+        usage.cost = hidden_params.get("response_cost")
+
+    return Completion(
+        id=api_completion.id,
+        created=api_completion.created,
+        usage=usage,
+        choices=choices,
+        name=name,
+        system_fingerprint=api_completion.system_fingerprint,
+        model=api_completion.model,
+        hidden_params=api_completion._hidden_params,  # type: ignore[union-attr]
+        response_ms=getattr(api_completion, "_response_ms", None),
+    )
+
+
+def to_api_completion(completion: Completion) -> LiteLLMCompletion:
+    raise NotImplementedError