grasp_agents 0.4.6__py3-none-any.whl → 0.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- grasp_agents/cloud_llm.py +191 -218
- grasp_agents/comm_processor.py +101 -100
- grasp_agents/errors.py +69 -9
- grasp_agents/litellm/__init__.py +106 -0
- grasp_agents/litellm/completion_chunk_converters.py +68 -0
- grasp_agents/litellm/completion_converters.py +72 -0
- grasp_agents/litellm/converters.py +138 -0
- grasp_agents/litellm/lite_llm.py +210 -0
- grasp_agents/litellm/message_converters.py +66 -0
- grasp_agents/llm.py +84 -49
- grasp_agents/llm_agent.py +136 -120
- grasp_agents/llm_agent_memory.py +3 -3
- grasp_agents/llm_policy_executor.py +167 -174
- grasp_agents/memory.py +4 -0
- grasp_agents/openai/__init__.py +24 -9
- grasp_agents/openai/completion_chunk_converters.py +6 -6
- grasp_agents/openai/completion_converters.py +12 -14
- grasp_agents/openai/content_converters.py +1 -3
- grasp_agents/openai/converters.py +6 -8
- grasp_agents/openai/message_converters.py +21 -3
- grasp_agents/openai/openai_llm.py +155 -103
- grasp_agents/openai/tool_converters.py +4 -6
- grasp_agents/packet.py +5 -2
- grasp_agents/packet_pool.py +14 -13
- grasp_agents/printer.py +234 -72
- grasp_agents/processor.py +228 -88
- grasp_agents/prompt_builder.py +2 -2
- grasp_agents/run_context.py +11 -20
- grasp_agents/runner.py +42 -0
- grasp_agents/typing/completion.py +16 -9
- grasp_agents/typing/completion_chunk.py +51 -22
- grasp_agents/typing/events.py +95 -19
- grasp_agents/typing/message.py +25 -1
- grasp_agents/typing/tool.py +2 -0
- grasp_agents/usage_tracker.py +31 -37
- grasp_agents/utils.py +95 -84
- grasp_agents/workflow/looped_workflow.py +60 -11
- grasp_agents/workflow/sequential_workflow.py +43 -11
- grasp_agents/workflow/workflow_processor.py +25 -24
- {grasp_agents-0.4.6.dist-info → grasp_agents-0.5.0.dist-info}/METADATA +7 -6
- grasp_agents-0.5.0.dist-info/RECORD +57 -0
- grasp_agents-0.4.6.dist-info/RECORD +0 -50
- {grasp_agents-0.4.6.dist-info → grasp_agents-0.5.0.dist-info}/WHEEL +0 -0
- {grasp_agents-0.4.6.dist-info → grasp_agents-0.5.0.dist-info}/licenses/LICENSE.md +0 -0
grasp_agents/comm_processor.py
CHANGED
@@ -2,24 +2,18 @@ import logging
|
|
2
2
|
from collections.abc import AsyncIterator, Sequence
|
3
3
|
from typing import Any, ClassVar, Generic, Protocol, TypeVar, cast
|
4
4
|
|
5
|
-
from
|
6
|
-
from pydantic.json_schema import SkipJsonSchema
|
7
|
-
|
5
|
+
from .errors import PacketRoutingError
|
8
6
|
from .memory import MemT
|
9
7
|
from .packet import Packet
|
10
8
|
from .packet_pool import PacketPool
|
11
9
|
from .processor import Processor
|
12
10
|
from .run_context import CtxT, RunContext
|
13
|
-
from .typing.events import Event,
|
11
|
+
from .typing.events import Event, ProcPacketOutputEvent, RunResultEvent
|
14
12
|
from .typing.io import InT, OutT_co, ProcName
|
15
13
|
|
16
14
|
logger = logging.getLogger(__name__)
|
17
15
|
|
18
16
|
|
19
|
-
class DynCommPayload(BaseModel):
|
20
|
-
selected_recipients: SkipJsonSchema[Sequence[ProcName]]
|
21
|
-
|
22
|
-
|
23
17
|
_OutT_contra = TypeVar("_OutT_contra", contravariant=True)
|
24
18
|
|
25
19
|
|
@@ -27,10 +21,16 @@ class ExitCommunicationHandler(Protocol[_OutT_contra, CtxT]):
|
|
27
21
|
def __call__(
|
28
22
|
self,
|
29
23
|
out_packet: Packet[_OutT_contra],
|
30
|
-
ctx: RunContext[CtxT]
|
24
|
+
ctx: RunContext[CtxT],
|
31
25
|
) -> bool: ...
|
32
26
|
|
33
27
|
|
28
|
+
class SetRecipientsHandler(Protocol[_OutT_contra, CtxT]):
|
29
|
+
def __call__(
|
30
|
+
self, out_packet: Packet[_OutT_contra], ctx: RunContext[CtxT]
|
31
|
+
) -> None: ...
|
32
|
+
|
33
|
+
|
34
34
|
class CommProcessor(
|
35
35
|
Processor[InT, OutT_co, MemT, CtxT],
|
36
36
|
Generic[InT, OutT_co, MemT, CtxT],
|
@@ -46,49 +46,45 @@ class CommProcessor(
|
|
46
46
|
*,
|
47
47
|
recipients: Sequence[ProcName] | None = None,
|
48
48
|
packet_pool: PacketPool[CtxT] | None = None,
|
49
|
-
|
49
|
+
max_retries: int = 0,
|
50
50
|
) -> None:
|
51
|
-
super().__init__(name=name,
|
51
|
+
super().__init__(name=name, max_retries=max_retries)
|
52
52
|
|
53
53
|
self.recipients = recipients or []
|
54
|
-
|
55
54
|
self._packet_pool = packet_pool
|
56
55
|
self._is_listening = False
|
56
|
+
|
57
57
|
self._exit_communication_impl: (
|
58
58
|
ExitCommunicationHandler[OutT_co, CtxT] | None
|
59
59
|
) = None
|
60
|
+
self._set_recipients_impl: SetRecipientsHandler[OutT_co, CtxT] | None = None
|
60
61
|
|
61
62
|
@property
|
62
63
|
def packet_pool(self) -> PacketPool[CtxT] | None:
|
63
64
|
return self._packet_pool
|
64
65
|
|
65
|
-
|
66
|
-
|
67
|
-
|
68
|
-
selected_recipients_per_payload = [
|
69
|
-
set(p.selected_recipients or []) for p in payloads_
|
70
|
-
]
|
71
|
-
assert all(
|
72
|
-
x == selected_recipients_per_payload[0]
|
73
|
-
for x in selected_recipients_per_payload
|
74
|
-
), "All payloads must have the same recipient IDs for dynamic routing"
|
75
|
-
|
76
|
-
assert payloads_[0].selected_recipients is not None
|
77
|
-
selected_recipients = payloads_[0].selected_recipients
|
66
|
+
@property
|
67
|
+
def is_listening(self) -> bool:
|
68
|
+
return self._is_listening
|
78
69
|
|
79
|
-
|
80
|
-
|
81
|
-
|
82
|
-
|
70
|
+
def _set_recipients(
|
71
|
+
self, out_packet: Packet[OutT_co], ctx: RunContext[CtxT]
|
72
|
+
) -> None:
|
73
|
+
if self._set_recipients_impl:
|
74
|
+
self._set_recipients_impl(out_packet=out_packet, ctx=ctx)
|
75
|
+
return
|
83
76
|
|
84
|
-
|
77
|
+
out_packet.recipients = self.recipients
|
85
78
|
|
86
|
-
|
87
|
-
|
79
|
+
def _validate_routing(self, recipients: Sequence[ProcName]) -> Sequence[ProcName]:
|
80
|
+
for r in recipients:
|
81
|
+
if r not in self.recipients:
|
82
|
+
raise PacketRoutingError(
|
83
|
+
selected_recipient=r,
|
84
|
+
allowed_recipients=cast("list[str]", self.recipients),
|
85
|
+
)
|
88
86
|
|
89
|
-
|
90
|
-
"All payloads must be either DCommAgentPayload or not DCommAgentPayload"
|
91
|
-
)
|
87
|
+
return self.recipients
|
92
88
|
|
93
89
|
async def run(
|
94
90
|
self,
|
@@ -97,117 +93,122 @@ class CommProcessor(
|
|
97
93
|
in_packet: Packet[InT] | None = None,
|
98
94
|
in_args: InT | Sequence[InT] | None = None,
|
99
95
|
forgetful: bool = False,
|
100
|
-
|
96
|
+
call_id: str | None = None,
|
101
97
|
ctx: RunContext[CtxT] | None = None,
|
102
98
|
) -> Packet[OutT_co]:
|
99
|
+
call_id = self._generate_call_id(call_id)
|
100
|
+
|
103
101
|
out_packet = await super().run(
|
104
102
|
chat_inputs=chat_inputs,
|
105
103
|
in_packet=in_packet,
|
106
104
|
in_args=in_args,
|
107
105
|
forgetful=forgetful,
|
108
|
-
|
106
|
+
call_id=call_id,
|
109
107
|
ctx=ctx,
|
110
108
|
)
|
111
|
-
recipients = self._validate_routing(out_packet.payloads)
|
112
|
-
routed_out_packet = Packet(
|
113
|
-
payloads=out_packet.payloads, sender=self.name, recipients=recipients
|
114
|
-
)
|
115
|
-
if self._packet_pool is not None and in_packet is None and in_args is None:
|
116
|
-
# If no input packet or args, we assume this is the first run.
|
117
|
-
await self._packet_pool.post(routed_out_packet)
|
118
109
|
|
119
|
-
|
110
|
+
if self._packet_pool is not None:
|
111
|
+
if ctx is None:
|
112
|
+
raise ValueError("RunContext must be provided when using PacketPool")
|
113
|
+
if self._exit_communication(out_packet=out_packet, ctx=ctx):
|
114
|
+
ctx.result = out_packet
|
115
|
+
await self._packet_pool.stop_all()
|
116
|
+
return out_packet
|
117
|
+
|
118
|
+
self._set_recipients(out_packet=out_packet, ctx=ctx)
|
119
|
+
out_packet.recipients = self._validate_routing(out_packet.recipients)
|
120
|
+
|
121
|
+
await self._packet_pool.post(out_packet)
|
122
|
+
|
123
|
+
return out_packet
|
120
124
|
|
121
125
|
async def run_stream(
|
122
126
|
self,
|
123
127
|
chat_inputs: Any | None = None,
|
124
128
|
*,
|
125
129
|
in_packet: Packet[InT] | None = None,
|
126
|
-
in_args: InT | None = None,
|
130
|
+
in_args: InT | Sequence[InT] | None = None,
|
127
131
|
forgetful: bool = False,
|
128
|
-
|
132
|
+
call_id: str | None = None,
|
129
133
|
ctx: RunContext[CtxT] | None = None,
|
130
134
|
) -> AsyncIterator[Event[Any]]:
|
135
|
+
call_id = self._generate_call_id(call_id)
|
136
|
+
|
131
137
|
out_packet: Packet[OutT_co] | None = None
|
132
138
|
async for event in super().run_stream(
|
133
139
|
chat_inputs=chat_inputs,
|
134
140
|
in_packet=in_packet,
|
135
141
|
in_args=in_args,
|
136
142
|
forgetful=forgetful,
|
137
|
-
|
143
|
+
call_id=call_id,
|
138
144
|
ctx=ctx,
|
139
145
|
):
|
140
|
-
if isinstance(event,
|
146
|
+
if isinstance(event, ProcPacketOutputEvent):
|
141
147
|
out_packet = event.data
|
142
148
|
else:
|
143
149
|
yield event
|
144
150
|
|
145
151
|
if out_packet is None:
|
146
|
-
|
152
|
+
return
|
153
|
+
|
154
|
+
if self._packet_pool is not None:
|
155
|
+
if ctx is None:
|
156
|
+
raise ValueError("RunContext must be provided when using PacketPool")
|
157
|
+
if self._exit_communication(out_packet=out_packet, ctx=ctx):
|
158
|
+
ctx.result = out_packet
|
159
|
+
yield RunResultEvent(
|
160
|
+
data=out_packet, proc_name=self.name, call_id=call_id
|
161
|
+
)
|
162
|
+
await self._packet_pool.stop_all()
|
163
|
+
return
|
164
|
+
|
165
|
+
self._set_recipients(out_packet=out_packet, ctx=ctx)
|
166
|
+
out_packet.recipients = self._validate_routing(out_packet.recipients)
|
167
|
+
|
168
|
+
await self._packet_pool.post(out_packet)
|
147
169
|
|
148
|
-
|
149
|
-
|
150
|
-
payloads=out_packet.payloads, sender=self.name, recipients=recipients
|
170
|
+
yield ProcPacketOutputEvent(
|
171
|
+
data=out_packet, proc_name=self.name, call_id=call_id
|
151
172
|
)
|
152
|
-
if self._packet_pool is not None and in_packet is None and in_args is None:
|
153
|
-
# If no input packet or args, we assume this is the first run.
|
154
|
-
await self._packet_pool.post(routed_out_packet)
|
155
173
|
|
156
|
-
|
174
|
+
def start_listening(self, ctx: RunContext[CtxT], **run_kwargs: Any) -> None:
|
175
|
+
if self._packet_pool is None:
|
176
|
+
raise RuntimeError("Packet pool must be initialized before listening")
|
157
177
|
|
158
|
-
|
159
|
-
|
160
|
-
|
161
|
-
self._exit_communication_impl = func
|
178
|
+
if self._is_listening:
|
179
|
+
return
|
180
|
+
self._is_listening = True
|
162
181
|
|
163
|
-
|
182
|
+
self._packet_pool.register_packet_handler(
|
183
|
+
processor_name=self.name,
|
184
|
+
handler=self.run_stream if ctx.is_streaming else self.run, # type: ignore[call-arg]
|
185
|
+
ctx=ctx,
|
186
|
+
**run_kwargs,
|
187
|
+
)
|
164
188
|
|
165
189
|
def _exit_communication(
|
166
|
-
self, out_packet: Packet[OutT_co], ctx: RunContext[CtxT]
|
190
|
+
self, out_packet: Packet[OutT_co], ctx: RunContext[CtxT]
|
167
191
|
) -> bool:
|
168
192
|
if self._exit_communication_impl:
|
169
193
|
return self._exit_communication_impl(out_packet=out_packet, ctx=ctx)
|
170
194
|
|
171
195
|
return False
|
172
196
|
|
173
|
-
|
174
|
-
self,
|
175
|
-
|
176
|
-
|
177
|
-
**run_kwargs: Any,
|
178
|
-
) -> None:
|
179
|
-
assert self._packet_pool is not None, "Packet pool must be initialized"
|
180
|
-
|
181
|
-
out_packet = await self.run(ctx=ctx, in_packet=packet, **run_kwargs)
|
182
|
-
|
183
|
-
if self._exit_communication(out_packet=out_packet, ctx=ctx):
|
184
|
-
await self._packet_pool.stop_all()
|
185
|
-
return
|
186
|
-
|
187
|
-
await self._packet_pool.post(out_packet)
|
188
|
-
|
189
|
-
@property
|
190
|
-
def is_listening(self) -> bool:
|
191
|
-
return self._is_listening
|
192
|
-
|
193
|
-
async def start_listening(
|
194
|
-
self, ctx: RunContext[CtxT] | None = None, **run_kwargs: Any
|
195
|
-
) -> None:
|
196
|
-
assert self._packet_pool is not None, "Packet pool must be initialized"
|
197
|
+
def exit_communication(
|
198
|
+
self, func: ExitCommunicationHandler[OutT_co, CtxT]
|
199
|
+
) -> ExitCommunicationHandler[OutT_co, CtxT]:
|
200
|
+
self._exit_communication_impl = func
|
197
201
|
|
198
|
-
|
199
|
-
return
|
202
|
+
return func
|
200
203
|
|
201
|
-
|
202
|
-
self
|
203
|
-
|
204
|
-
|
205
|
-
ctx=ctx,
|
206
|
-
**run_kwargs,
|
207
|
-
)
|
204
|
+
def set_recipients(
|
205
|
+
self, func: SetRecipientsHandler[OutT_co, CtxT]
|
206
|
+
) -> SetRecipientsHandler[OutT_co, CtxT]:
|
207
|
+
self._select_recipients_impl = func
|
208
208
|
|
209
|
-
|
210
|
-
assert self._packet_pool is not None, "Packet pool must be initialized"
|
209
|
+
return func
|
211
210
|
|
212
|
-
|
213
|
-
|
211
|
+
# async def stop_listening(self) -> None:
|
212
|
+
# assert self._packet_pool is not None
|
213
|
+
# self._is_listening = False
|
214
|
+
# await self._packet_pool.unregister_packet_handler(self.name)
|
grasp_agents/errors.py
CHANGED
@@ -1,29 +1,46 @@
|
|
1
|
-
|
1
|
+
# from openai import APIResponseValidationError
|
2
|
+
class CompletionError(Exception):
|
2
3
|
pass
|
3
4
|
|
4
5
|
|
5
|
-
class
|
6
|
+
class CombineCompletionChunksError(Exception):
|
6
7
|
pass
|
7
8
|
|
8
9
|
|
9
|
-
class
|
10
|
+
class ProcInputValidationError(Exception):
|
10
11
|
pass
|
11
12
|
|
12
13
|
|
13
|
-
class
|
14
|
+
class ProcOutputValidationError(Exception):
|
14
15
|
pass
|
15
16
|
|
16
17
|
|
17
|
-
class
|
18
|
-
|
18
|
+
class AgentFinalAnswerError(Exception):
|
19
|
+
def __init__(self, message: str | None = None) -> None:
|
20
|
+
super().__init__(
|
21
|
+
message or "Final answer tool call did not return a final answer message."
|
22
|
+
)
|
23
|
+
self.message = message
|
19
24
|
|
20
25
|
|
21
|
-
class
|
26
|
+
class WorkflowConstructionError(Exception):
|
22
27
|
pass
|
23
28
|
|
24
29
|
|
25
|
-
class
|
26
|
-
|
30
|
+
class PacketRoutingError(Exception):
|
31
|
+
def __init__(
|
32
|
+
self,
|
33
|
+
selected_recipient: str,
|
34
|
+
allowed_recipients: list[str],
|
35
|
+
message: str | None = None,
|
36
|
+
) -> None:
|
37
|
+
default_message = (
|
38
|
+
f"Selected recipient '{selected_recipient}' is not in the allowed "
|
39
|
+
f"recipients: {allowed_recipients}"
|
40
|
+
)
|
41
|
+
super().__init__(message or default_message)
|
42
|
+
self.selected_recipient = selected_recipient
|
43
|
+
self.allowed_recipients = allowed_recipients
|
27
44
|
|
28
45
|
|
29
46
|
class SystemPromptBuilderError(Exception):
|
@@ -32,3 +49,46 @@ class SystemPromptBuilderError(Exception):
|
|
32
49
|
|
33
50
|
class InputPromptBuilderError(Exception):
|
34
51
|
pass
|
52
|
+
|
53
|
+
|
54
|
+
class PyJSONStringParsingError(Exception):
|
55
|
+
def __init__(self, s: str, message: str | None = None) -> None:
|
56
|
+
super().__init__(
|
57
|
+
message
|
58
|
+
or "Both ast.literal_eval and json.loads failed to parse the following "
|
59
|
+
f"JSON/Python string:\n{s}"
|
60
|
+
)
|
61
|
+
self.s = s
|
62
|
+
|
63
|
+
|
64
|
+
class JSONSchemaValidationError(Exception):
|
65
|
+
def __init__(self, s: str, schema: object, message: str | None = None) -> None:
|
66
|
+
super().__init__(
|
67
|
+
message
|
68
|
+
or f"JSON schema validation failed for:\n{s}\nExpected type: {schema}"
|
69
|
+
)
|
70
|
+
self.s = s
|
71
|
+
self.schema = schema
|
72
|
+
|
73
|
+
|
74
|
+
class LLMToolCallValidationError(Exception):
|
75
|
+
def __init__(
|
76
|
+
self, tool_name: str, tool_args: str, message: str | None = None
|
77
|
+
) -> None:
|
78
|
+
super().__init__(
|
79
|
+
message
|
80
|
+
or f"Failed to validate tool call '{tool_name}' with arguments:"
|
81
|
+
f"\n{tool_args}."
|
82
|
+
)
|
83
|
+
self.tool_name = tool_name
|
84
|
+
self.tool_args = tool_args
|
85
|
+
|
86
|
+
|
87
|
+
class LLMResponseValidationError(JSONSchemaValidationError):
|
88
|
+
def __init__(self, s: str, schema: object, message: str | None = None) -> None:
|
89
|
+
super().__init__(
|
90
|
+
s,
|
91
|
+
schema,
|
92
|
+
message
|
93
|
+
or f"Failed to validate LLM response:\n{s}\nExpected type: {schema}",
|
94
|
+
)
|
@@ -0,0 +1,106 @@
|
|
1
|
+
# pyright: reportUnusedImport=false
|
2
|
+
|
3
|
+
from litellm.types.utils import ChatCompletionMessageToolCall as LiteLLMToolCall
|
4
|
+
from litellm.types.utils import Choices as LiteLLMChoice
|
5
|
+
from litellm.types.utils import Function as LiteLLMFunction
|
6
|
+
from litellm.types.utils import Message as LiteLLMCompletionMessage
|
7
|
+
from litellm.types.utils import ModelResponse as LiteLLMCompletion
|
8
|
+
from litellm.types.utils import ModelResponseStream as LiteLLMCompletionChunk
|
9
|
+
from litellm.types.utils import StreamingChoices as LiteLLMChunkChoice
|
10
|
+
from litellm.types.utils import Usage as LiteLLMUsage
|
11
|
+
from openai._streaming import (
|
12
|
+
AsyncStream as OpenAIAsyncStream, # type: ignore[import] # noqa: PLC2701
|
13
|
+
)
|
14
|
+
from openai.types import CompletionUsage as OpenAIUsage
|
15
|
+
from openai.types.chat.chat_completion import ChatCompletion as OpenAICompletion
|
16
|
+
from openai.types.chat.chat_completion import (
|
17
|
+
ChoiceLogprobs as OpenAIChoiceLogprobs,
|
18
|
+
)
|
19
|
+
from openai.types.chat.chat_completion_assistant_message_param import (
|
20
|
+
ChatCompletionAssistantMessageParam as OpenAIAssistantMessageParam,
|
21
|
+
)
|
22
|
+
from openai.types.chat.chat_completion_chunk import (
|
23
|
+
ChatCompletionChunk as OpenAICompletionChunk,
|
24
|
+
)
|
25
|
+
from openai.types.chat.chat_completion_chunk import (
|
26
|
+
Choice as OpenAIChunkChoice,
|
27
|
+
)
|
28
|
+
from openai.types.chat.chat_completion_chunk import (
|
29
|
+
ChoiceDelta as OpenAIChunkChoiceDelta,
|
30
|
+
)
|
31
|
+
from openai.types.chat.chat_completion_chunk import (
|
32
|
+
ChoiceDeltaToolCall as OpenAIChunkChoiceDeltaToolCall,
|
33
|
+
)
|
34
|
+
from openai.types.chat.chat_completion_content_part_image_param import (
|
35
|
+
ChatCompletionContentPartImageParam as OpenAIContentPartImageParam,
|
36
|
+
)
|
37
|
+
from openai.types.chat.chat_completion_content_part_image_param import (
|
38
|
+
ImageURL as OpenAIImageURL,
|
39
|
+
)
|
40
|
+
from openai.types.chat.chat_completion_content_part_param import (
|
41
|
+
ChatCompletionContentPartParam as OpenAIContentPartParam,
|
42
|
+
)
|
43
|
+
from openai.types.chat.chat_completion_content_part_text_param import (
|
44
|
+
ChatCompletionContentPartTextParam as OpenAIContentPartTextParam,
|
45
|
+
)
|
46
|
+
from openai.types.chat.chat_completion_developer_message_param import (
|
47
|
+
ChatCompletionDeveloperMessageParam as OpenAIDeveloperMessageParam,
|
48
|
+
)
|
49
|
+
from openai.types.chat.chat_completion_function_message_param import (
|
50
|
+
ChatCompletionFunctionMessageParam as OpenAIFunctionMessageParam,
|
51
|
+
)
|
52
|
+
from openai.types.chat.chat_completion_message import (
|
53
|
+
ChatCompletionMessage as OpenAICompletionMessage,
|
54
|
+
)
|
55
|
+
from openai.types.chat.chat_completion_message_param import (
|
56
|
+
ChatCompletionMessageParam as OpenAIMessageParam,
|
57
|
+
)
|
58
|
+
from openai.types.chat.chat_completion_message_tool_call_param import (
|
59
|
+
ChatCompletionMessageToolCallParam as OpenAIToolCallParam,
|
60
|
+
)
|
61
|
+
from openai.types.chat.chat_completion_message_tool_call_param import (
|
62
|
+
Function as OpenAIToolCallFunction,
|
63
|
+
)
|
64
|
+
from openai.types.chat.chat_completion_named_tool_choice_param import (
|
65
|
+
ChatCompletionNamedToolChoiceParam as OpenAINamedToolChoiceParam,
|
66
|
+
)
|
67
|
+
from openai.types.chat.chat_completion_named_tool_choice_param import (
|
68
|
+
Function as OpenAINamedToolChoiceFunction,
|
69
|
+
)
|
70
|
+
from openai.types.chat.chat_completion_prediction_content_param import (
|
71
|
+
ChatCompletionPredictionContentParam as OpenAIPredictionContentParam,
|
72
|
+
)
|
73
|
+
from openai.types.chat.chat_completion_stream_options_param import (
|
74
|
+
ChatCompletionStreamOptionsParam as OpenAIStreamOptionsParam,
|
75
|
+
)
|
76
|
+
from openai.types.chat.chat_completion_system_message_param import (
|
77
|
+
ChatCompletionSystemMessageParam as OpenAISystemMessageParam,
|
78
|
+
)
|
79
|
+
from openai.types.chat.chat_completion_tool_choice_option_param import (
|
80
|
+
ChatCompletionToolChoiceOptionParam as OpenAIToolChoiceOptionParam,
|
81
|
+
)
|
82
|
+
from openai.types.chat.chat_completion_tool_message_param import (
|
83
|
+
ChatCompletionToolMessageParam as OpenAIToolMessageParam,
|
84
|
+
)
|
85
|
+
from openai.types.chat.chat_completion_tool_param import (
|
86
|
+
ChatCompletionToolParam as OpenAIToolParam,
|
87
|
+
)
|
88
|
+
from openai.types.chat.chat_completion_user_message_param import (
|
89
|
+
ChatCompletionUserMessageParam as OpenAIUserMessageParam,
|
90
|
+
)
|
91
|
+
from openai.types.chat.parsed_chat_completion import (
|
92
|
+
ParsedChatCompletion as OpenAIParsedCompletion,
|
93
|
+
)
|
94
|
+
from openai.types.chat.parsed_chat_completion import (
|
95
|
+
ParsedChatCompletionMessage as OpenAIParsedCompletionMessage,
|
96
|
+
)
|
97
|
+
from openai.types.chat.parsed_chat_completion import (
|
98
|
+
ParsedChoice as OpenAIParsedChoice,
|
99
|
+
)
|
100
|
+
from openai.types.shared_params.function_definition import (
|
101
|
+
FunctionDefinition as OpenAIFunctionDefinition,
|
102
|
+
)
|
103
|
+
|
104
|
+
from .lite_llm import LiteLLM, LiteLLMSettings
|
105
|
+
|
106
|
+
__all__ = ["LiteLLM", "LiteLLMSettings"]
|
@@ -0,0 +1,68 @@
|
|
1
|
+
from ..openai.completion_converters import from_api_completion_usage
|
2
|
+
from ..typing.completion_chunk import (
|
3
|
+
CompletionChunk,
|
4
|
+
CompletionChunkChoice,
|
5
|
+
CompletionChunkChoiceDelta,
|
6
|
+
CompletionChunkDeltaToolCall,
|
7
|
+
)
|
8
|
+
from . import LiteLLMChunkChoice, LiteLLMCompletionChunk
|
9
|
+
|
10
|
+
|
11
|
+
def from_api_completion_chunk(
|
12
|
+
api_completion_chunk: LiteLLMCompletionChunk, name: str | None = None
|
13
|
+
) -> CompletionChunk:
|
14
|
+
choices: list[CompletionChunkChoice] = []
|
15
|
+
|
16
|
+
for api_choice in api_completion_chunk.choices:
|
17
|
+
assert isinstance(api_choice, LiteLLMChunkChoice)
|
18
|
+
|
19
|
+
api_delta = api_choice.delta
|
20
|
+
|
21
|
+
delta = CompletionChunkChoiceDelta(
|
22
|
+
tool_calls=[
|
23
|
+
CompletionChunkDeltaToolCall(
|
24
|
+
id=tool_call.id,
|
25
|
+
index=tool_call.index,
|
26
|
+
tool_name=tool_call.function.name,
|
27
|
+
tool_arguments=tool_call.function.arguments,
|
28
|
+
)
|
29
|
+
for tool_call in (api_delta.tool_calls or [])
|
30
|
+
if tool_call.function
|
31
|
+
],
|
32
|
+
content=api_delta.content, # type: ignore[assignment, arg-type]
|
33
|
+
role=api_delta.role, # type: ignore[assignment, arg-type]
|
34
|
+
thinking_blocks=getattr(api_delta, "thinking_blocks", None),
|
35
|
+
annotations=getattr(api_delta, "annotations", None),
|
36
|
+
reasoning_content=getattr(api_delta, "reasoning_content", None),
|
37
|
+
provider_specific_fields=api_delta.provider_specific_fields,
|
38
|
+
refusal=getattr(api_delta, "refusal", None),
|
39
|
+
)
|
40
|
+
|
41
|
+
choice = CompletionChunkChoice(
|
42
|
+
delta=delta,
|
43
|
+
index=api_choice.index,
|
44
|
+
finish_reason=api_choice.finish_reason, # type: ignore[assignment, arg-type]
|
45
|
+
logprobs=getattr(api_choice, "logprobs", None),
|
46
|
+
)
|
47
|
+
|
48
|
+
choices.append(choice)
|
49
|
+
|
50
|
+
api_usage = getattr(api_completion_chunk, "usage", None)
|
51
|
+
usage = None
|
52
|
+
if api_usage is not None:
|
53
|
+
usage = from_api_completion_usage(api_usage)
|
54
|
+
hidden_params = getattr(api_completion_chunk, "_hidden_params", {})
|
55
|
+
usage.cost = getattr(hidden_params, "response_cost", None)
|
56
|
+
|
57
|
+
return CompletionChunk(
|
58
|
+
id=api_completion_chunk.id,
|
59
|
+
model=api_completion_chunk.model,
|
60
|
+
name=name,
|
61
|
+
created=api_completion_chunk.created,
|
62
|
+
system_fingerprint=api_completion_chunk.system_fingerprint,
|
63
|
+
choices=choices,
|
64
|
+
usage=usage,
|
65
|
+
provider_specific_fields=api_completion_chunk.provider_specific_fields,
|
66
|
+
hidden_params=api_completion_chunk._hidden_params, # type: ignore[union-attr]
|
67
|
+
response_ms=getattr(api_completion_chunk, "_response_ms", None),
|
68
|
+
)
|
@@ -0,0 +1,72 @@
|
|
1
|
+
from typing import cast
|
2
|
+
|
3
|
+
from ..typing.completion import Completion, CompletionChoice, Usage
|
4
|
+
from . import LiteLLMChoice, LiteLLMCompletion, LiteLLMUsage
|
5
|
+
from .message_converters import from_api_assistant_message
|
6
|
+
|
7
|
+
|
8
|
+
def from_api_completion_usage(api_usage: LiteLLMUsage) -> Usage:
|
9
|
+
reasoning_tokens = None
|
10
|
+
cached_tokens = None
|
11
|
+
|
12
|
+
if api_usage.completion_tokens_details is not None:
|
13
|
+
reasoning_tokens = api_usage.completion_tokens_details.reasoning_tokens
|
14
|
+
if api_usage.prompt_tokens_details is not None:
|
15
|
+
cached_tokens = api_usage.prompt_tokens_details.cached_tokens
|
16
|
+
|
17
|
+
input_tokens = api_usage.prompt_tokens - (cached_tokens or 0)
|
18
|
+
output_tokens = api_usage.completion_tokens # - (reasoning_tokens or 0)
|
19
|
+
|
20
|
+
return Usage(
|
21
|
+
input_tokens=input_tokens,
|
22
|
+
output_tokens=output_tokens,
|
23
|
+
reasoning_tokens=reasoning_tokens,
|
24
|
+
cached_tokens=cached_tokens,
|
25
|
+
)
|
26
|
+
|
27
|
+
|
28
|
+
def from_api_completion(
|
29
|
+
api_completion: LiteLLMCompletion, name: str | None = None
|
30
|
+
) -> Completion:
|
31
|
+
choices: list[CompletionChoice] = []
|
32
|
+
usage: Usage | None = None
|
33
|
+
|
34
|
+
for api_choice in api_completion.choices:
|
35
|
+
assert isinstance(api_choice, LiteLLMChoice)
|
36
|
+
|
37
|
+
message = from_api_assistant_message(api_choice.message, name=name)
|
38
|
+
|
39
|
+
choices.append(
|
40
|
+
CompletionChoice(
|
41
|
+
index=api_choice.index,
|
42
|
+
message=message,
|
43
|
+
finish_reason=api_choice.finish_reason, # type: ignore[assignment, arg-type]
|
44
|
+
logprobs=getattr(api_choice, "logprobs", None),
|
45
|
+
provider_specific_fields=getattr(
|
46
|
+
api_choice, "provider_specific_fields", None
|
47
|
+
),
|
48
|
+
)
|
49
|
+
)
|
50
|
+
|
51
|
+
api_usage = getattr(api_completion, "usage", None)
|
52
|
+
usage = None
|
53
|
+
if api_usage:
|
54
|
+
usage = from_api_completion_usage(cast("LiteLLMUsage", api_usage))
|
55
|
+
hidden_params = getattr(api_completion, "_hidden_params", {})
|
56
|
+
usage.cost = hidden_params.get("response_cost")
|
57
|
+
|
58
|
+
return Completion(
|
59
|
+
id=api_completion.id,
|
60
|
+
created=api_completion.created,
|
61
|
+
usage=usage,
|
62
|
+
choices=choices,
|
63
|
+
name=name,
|
64
|
+
system_fingerprint=api_completion.system_fingerprint,
|
65
|
+
model=api_completion.model,
|
66
|
+
hidden_params=api_completion._hidden_params, # type: ignore[union-attr]
|
67
|
+
response_ms=getattr(api_completion, "_response_ms", None),
|
68
|
+
)
|
69
|
+
|
70
|
+
|
71
|
+
def to_api_completion(completion: Completion) -> LiteLLMCompletion:
|
72
|
+
raise NotImplementedError
|