llmeter 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- llmeter/__init__.py +0 -0
- llmeter/endpoints/__init__.py +13 -0
- llmeter/endpoints/base.py +232 -0
- llmeter/endpoints/bedrock.py +192 -0
- llmeter/endpoints/litellm.py +165 -0
- llmeter/endpoints/openai.py +78 -0
- llmeter/endpoints/sagemaker.py +316 -0
- llmeter/experiments.py +163 -0
- llmeter/plotting.py +143 -0
- llmeter/prompt_utils.py +190 -0
- llmeter/results.py +228 -0
- llmeter/runner.py +511 -0
- llmeter/tokenizers.py +175 -0
- llmeter/utils.py +32 -0
- llmeter-0.1.0.dist-info/LICENSE +175 -0
- llmeter-0.1.0.dist-info/METADATA +125 -0
- llmeter-0.1.0.dist-info/NOTICE +1 -0
- llmeter-0.1.0.dist-info/RECORD +19 -0
- llmeter-0.1.0.dist-info/WHEEL +4 -0
llmeter/__init__.py
ADDED
|
File without changes
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
import importlib.util
|
|
2
|
+
|
|
3
|
+
from .base import Endpoint, InvocationResponse # noqa: F401
|
|
4
|
+
from .bedrock import BedrockConverse, BedrockConverseStream # noqa: F401
|
|
5
|
+
from .sagemaker import SageMakerEndpoint, SageMakerStreamEndpoint # noqa: F401
|
|
6
|
+
|
|
7
|
+
spec = importlib.util.find_spec("openai")
|
|
8
|
+
if spec:
|
|
9
|
+
from .openai import OpenAIEndpoint, OpenAICompletionEndpoint # noqa: F401
|
|
10
|
+
|
|
11
|
+
spec = importlib.util.find_spec("litellm")
|
|
12
|
+
if spec:
|
|
13
|
+
from .litellm import LiteLLM, LiteLLMStreaming # noqa: F401
|
|
@@ -0,0 +1,232 @@
|
|
|
1
|
+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
import importlib
|
|
5
|
+
import json
|
|
6
|
+
import os
|
|
7
|
+
from abc import ABC, abstractmethod
|
|
8
|
+
from dataclasses import asdict, dataclass
|
|
9
|
+
from typing import Dict, TypeVar
|
|
10
|
+
from uuid import uuid4
|
|
11
|
+
|
|
12
|
+
from upath import UPath as Path
|
|
13
|
+
|
|
14
|
+
# Type variable standing in for "the Endpoint subclass being constructed".
# On Python >= 3.11 this can be replaced with a direct import of `Self`.
Self = TypeVar("Self", bound="Endpoint")
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
@dataclass
|
|
20
|
+
class InvocationResponse:
|
|
21
|
+
"""
|
|
22
|
+
A class representing a invocation result.
|
|
23
|
+
|
|
24
|
+
Attributes:
|
|
25
|
+
response_text (str): The invocation output.
|
|
26
|
+
id (str): A unique identifier for the invocation.
|
|
27
|
+
time_to_last_token (float): The time taken to generate the response in seconds.
|
|
28
|
+
time_to_first_token (float): The time taken to receive the first token of the response in seconds.
|
|
29
|
+
num_tokens_output (Optional[int]): The number of tokens in the response.
|
|
30
|
+
num_tokens_input (Optional[int]): The number of tokens in the invocation payload.
|
|
31
|
+
input_prompt (str): The input prompt used in the invocation.
|
|
32
|
+
time_per_output_token (float): The average time taken to generate each token in the response.
|
|
33
|
+
error (str): Any error that occurred during invocation.
|
|
34
|
+
"""
|
|
35
|
+
|
|
36
|
+
response_text: str | None
|
|
37
|
+
id: str | None = None
|
|
38
|
+
input_prompt: str | None = None
|
|
39
|
+
time_to_first_token: float | None = None
|
|
40
|
+
time_to_last_token: float | None = None
|
|
41
|
+
num_tokens_input: int | None = None
|
|
42
|
+
num_tokens_output: int | None = None
|
|
43
|
+
time_per_output_token: float | None = None
|
|
44
|
+
error: str | None = None
|
|
45
|
+
|
|
46
|
+
def to_json(self, **kwargs) -> str:
|
|
47
|
+
return json.dumps(self.__dict__, **kwargs)
|
|
48
|
+
|
|
49
|
+
@staticmethod
|
|
50
|
+
def error_output(
|
|
51
|
+
input_prompt: str | None = None, error=None, id: str | None = None
|
|
52
|
+
):
|
|
53
|
+
return InvocationResponse(
|
|
54
|
+
id=id or uuid4().hex,
|
|
55
|
+
response_text=None,
|
|
56
|
+
input_prompt=input_prompt,
|
|
57
|
+
time_to_last_token=None,
|
|
58
|
+
error="invocation failed" if error is None else str(error),
|
|
59
|
+
)
|
|
60
|
+
|
|
61
|
+
def __repr__(self):
|
|
62
|
+
return self.to_json(default=str)
|
|
63
|
+
|
|
64
|
+
def __str__(self):
|
|
65
|
+
return self.to_json(indent=4, default=str)
|
|
66
|
+
|
|
67
|
+
def to_dict(self):
|
|
68
|
+
return asdict(self)
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
class Endpoint(ABC):
    """
    An abstract base class for endpoint implementations.

    This class defines the basic structure and interface for all endpoint classes.
    Subclasses must implement ``__init__`` and ``invoke``; configuration
    persistence (``save`` / ``to_dict`` / ``load`` / ``load_from_file``) is
    provided here.
    """

    @abstractmethod
    def __init__(
        self,
        endpoint_name: str,
        model_id: str,
        provider: str,
    ):
        """
        Store the identifying attributes shared by every endpoint.

        Args:
            endpoint_name (str): The name of the endpoint.
            model_id (str): The identifier of the model associated with this endpoint.
            provider (str): The provider of the endpoint.
        """
        self.endpoint_name = endpoint_name
        self.model_id = model_id
        self.provider = provider

    @abstractmethod
    def invoke(self, payload: Dict) -> "InvocationResponse":
        """
        Invoke the endpoint with the given payload.

        This method must be implemented by subclasses to define how the endpoint
        is invoked and how the response is processed.

        Args:
            payload (Dict): The input payload for the model.

        Returns:
            InvocationResponse: An object containing the model's response and associated metrics.

        Raises:
            NotImplementedError: If the method is not implemented by a subclass.
        """
        raise NotImplementedError

    @staticmethod
    def create_payload(*args, **kwargs):
        """
        Create a payload for the endpoint invocation.

        This static method should be implemented by subclasses to define
        how the payload is created based on the given arguments.

        Args:
            *args: Variable length argument list.
            **kwargs: Arbitrary keyword arguments.

        Returns:
            NotImplemented: This method returns NotImplemented in the base class.
        """
        return NotImplemented

    @classmethod
    def __subclasshook__(cls, C):
        """
        Determine if a class is considered a subclass of Endpoint.

        A class is treated as a (virtual) subclass of ``Endpoint`` when it, or
        any class in its MRO, defines an ``invoke`` attribute.

        Args:
            C: The class to check.

        Returns:
            bool or NotImplemented: True if ``invoke`` is found, otherwise
            NotImplemented so the regular subclass check is used.
        """
        if cls is Endpoint:
            if any("invoke" in B.__dict__ for B in C.__mro__):
                return True
        return NotImplemented

    def save(self, output_path: os.PathLike) -> os.PathLike:
        """
        Save the endpoint configuration to a JSON file.

        This method serializes the endpoint's configuration (excluding private
        attributes) to a JSON file at the specified path, creating parent
        directories as needed.

        Args:
            output_path (str | UPath): The path where the configuration file will be saved.

        Returns:
            The path the configuration was written to.
        """
        output_path = Path(output_path)
        # Build the configuration before opening the file so that a failure in
        # to_dict() cannot leave an empty, truncated file behind.
        endpoint_conf = self.to_dict()
        output_path.parent.mkdir(parents=True, exist_ok=True)
        with output_path.open("w") as f:
            json.dump(endpoint_conf, f, indent=4, default=str)
        return output_path

    def to_dict(self) -> Dict:
        """
        Convert the endpoint configuration to a dictionary.

        Returns:
            Dict: Every instance attribute whose name does not start with an
            underscore, plus ``endpoint_type`` set to the class name.
        """
        endpoint_conf = {k: v for k, v in vars(self).items() if not k.startswith("_")}
        endpoint_conf["endpoint_type"] = self.__class__.__name__
        return endpoint_conf

    @classmethod
    def load_from_file(cls, input_path: os.PathLike) -> "Self":
        """
        Load an endpoint configuration from a JSON file.

        Reads the JSON document at ``input_path`` and instantiates the endpoint
        it describes via :meth:`load`.

        Args:
            input_path (str | UPath): The path to the JSON configuration file.

        Returns:
            Endpoint: An instance of the appropriate endpoint class, initialized
            with the configuration from the file.
        """
        with Path(input_path).open("r") as f:
            data = json.load(f)
        return cls.load(data)

    @classmethod
    def load(cls, endpoint_config: Dict) -> "Self":  # type: ignore
        """
        Load an endpoint configuration from a dictionary.

        The ``endpoint_type`` entry names a class exposed by
        ``llmeter.endpoints``; the remaining entries are passed to that class
        as keyword arguments.

        NOTE(review): every remaining key is forwarded, so the target class's
        constructor must accept all keys produced by ``to_dict()`` (including
        ``provider``) or this raises ``TypeError``.

        Args:
            endpoint_config (Dict): A dictionary containing the endpoint
                configuration. It is not modified.

        Returns:
            Endpoint: An instance of the appropriate endpoint class, initialized
            with the configuration from the dictionary.

        Raises:
            KeyError: If ``endpoint_type`` is missing.
            AttributeError: If ``llmeter.endpoints`` exposes no such class.
        """
        # Work on a shallow copy: popping from the caller's dict would make the
        # same configuration unusable for a second load.
        config = dict(endpoint_config)
        endpoint_type = config.pop("endpoint_type")
        endpoint_module = importlib.import_module("llmeter.endpoints")
        endpoint_class = getattr(endpoint_module, endpoint_type)
        return endpoint_class(**config)
|
|
@@ -0,0 +1,192 @@
|
|
|
1
|
+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
import logging
|
|
5
|
+
import time
|
|
6
|
+
from typing import Dict, Sequence
|
|
7
|
+
from uuid import uuid4
|
|
8
|
+
|
|
9
|
+
import boto3
|
|
10
|
+
import jmespath
|
|
11
|
+
from botocore.exceptions import ClientError
|
|
12
|
+
|
|
13
|
+
from .base import Endpoint, InvocationResponse
|
|
14
|
+
|
|
15
|
+
logger = logging.getLogger(__name__)
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class BedrockBase(Endpoint):
    """
    Base class for Amazon Bedrock endpoints.

    This class provides the foundation for interacting with Amazon Bedrock services,
    including client initialization and payload handling.
    """

    def __init__(
        self,
        model_id: str,
        endpoint_name: str | None = None,
        region: str | None = None,
        inference_config: Dict | None = None,
    ):
        """
        Initialize the BedrockBase instance.

        Args:
            model_id (str): The ID of the model to use.
            endpoint_name (str | None, optional): Name recorded for this endpoint.
                Defaults to "amazon bedrock" when empty or None.
            region (str | None, optional): The AWS region to use. If None, uses the default session region.
            inference_config (Dict | None, optional): Configuration for inference,
                used by `invoke` when the payload carries no `inferenceConfig`.
        """
        # Honor a caller-supplied name; previously it was accepted and then
        # unconditionally replaced by the default.
        super().__init__(
            model_id=model_id,
            endpoint_name=endpoint_name or "amazon bedrock",
            provider="bedrock",
        )

        self.region = region or boto3.session.Session().region_name
        logger.info("Using AWS region: %s", self.region)

        self._bedrock_client = boto3.client("bedrock-runtime", region_name=self.region)
        self._inference_config = inference_config

    def _parse_payload(self, payload):
        """
        Parse the payload to extract text content.

        Args:
            payload (dict): The payload containing messages.

        Returns:
            str: The `text` entries of every message's content blocks, joined
            by newlines. Empty when the payload has no messages.
        """
        jpath = "[:].content[:].text"
        messages = payload.get("messages")
        # jmespath returns None when there is nothing to search; treat as empty.
        texts_per_message = jmespath.search(jpath, messages) or []
        return "\n".join(text for texts in texts_per_message for text in texts)

    @staticmethod
    def create_payload(
        user_message: str | Sequence[str], max_tokens: int = 256, **kwargs
    ):
        """
        Create a payload for the Bedrock Converse API request.

        Args:
            user_message (str | Sequence[str]): The user's message or a sequence of messages.
            max_tokens (int, optional): The maximum number of tokens to generate. Defaults to 256.
            **kwargs: Additional keyword arguments to include in the payload.
                An `inferenceConfig` entry is preserved, except that its
                `maxTokens` is always set from `max_tokens`.

        Returns:
            dict: The formatted payload for the Bedrock API request.
        """
        if isinstance(user_message, str):
            user_message = [user_message]
        payload: dict = {
            "messages": [
                {"role": "user", "content": [{"text": k}]} for k in user_message
            ],
        }
        payload.update(kwargs)
        # Build a new dict so a caller-supplied inferenceConfig is not mutated.
        payload["inferenceConfig"] = {
            **(payload.get("inferenceConfig") or {}),
            "maxTokens": max_tokens,
        }
        return payload
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
class BedrockConverse(BedrockBase):
    """Endpoint that calls the non-streaming `converse` operation of the Bedrock client."""

    def _parse_converse_response(self, response: Dict) -> InvocationResponse:
        """
        Turn a `converse` response dictionary into an InvocationResponse.

        The text is taken from the first content block of the output message;
        token counts come from the optional `usage` section.
        """
        first_block = response["output"]["message"]["content"][0]
        generated_text = first_block["text"]
        token_usage = response.get("usage", {})

        return InvocationResponse(
            id=uuid4().hex,
            response_text=generated_text,
            num_tokens_input=token_usage.get("inputTokens"),
            num_tokens_output=token_usage.get("outputTokens"),
        )

    def invoke(self, payload: Dict, **kwargs) -> InvocationResponse:
        """
        Send the payload to the model and time the call.

        Args:
            payload (Dict): Request arguments for `converse`; its entries take
                precedence over `kwargs` with the same key.
            **kwargs: Additional request arguments.

        Returns:
            InvocationResponse: The parsed response with the input prompt and
            the measured time to last token, or an error response if the
            client call raised.
        """
        request = {**kwargs, **payload}
        if request.get("inferenceConfig") is None:
            # Fall back to the configuration given at construction time.
            request["inferenceConfig"] = self._inference_config or {}
        request["modelId"] = self.model_id

        try:
            started = time.perf_counter()
            raw_response = self._bedrock_client.converse(**request)
            elapsed = time.perf_counter() - started
        except (ClientError, Exception) as e:
            logger.error(e)
            return InvocationResponse.error_output(id=uuid4().hex, error=str(e))

        result = self._parse_converse_response(raw_response)  # type: ignore
        result.input_prompt = self._parse_payload(request)
        result.time_to_last_token = elapsed
        return result
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
class BedrockConverseStream(BedrockConverse):
    """Endpoint that calls the streaming `converse_stream` operation of the Bedrock client."""

    def invoke(self, payload: Dict, **kwargs) -> InvocationResponse:
        """
        Send the payload to the model and consume the response stream.

        Args:
            payload (Dict): Request arguments for `converse_stream`; its entries
                take precedence over `kwargs` with the same key.
            **kwargs: Additional request arguments.

        Returns:
            InvocationResponse: The parsed response including the input prompt,
            or an error response if the client call raised.
        """
        payload = {**kwargs, **payload}
        if payload.get("inferenceConfig") is None:
            payload["inferenceConfig"] = self._inference_config or {}

        payload["modelId"] = self.model_id
        start_t = time.perf_counter()
        try:
            client_response = self._bedrock_client.converse_stream(**payload)
        except Exception as e:
            logger.error(e)
            return InvocationResponse.error_output(id=uuid4().hex, error=str(e))
        response = self._parse_conversation_stream(client_response, start_t)  # type: ignore
        response.input_prompt = self._parse_payload(payload)
        return response

    def _parse_conversation_stream(
        self, client_response: Dict, start_t: float
    ) -> InvocationResponse:
        """
        Consume the event stream and collect text, timings and token usage.

        Args:
            client_response (Dict): Response of `converse_stream`; its "stream"
                entry is iterated.
            start_t (float): `time.perf_counter()` value taken before the request.

        Returns:
            InvocationResponse: Concatenated text plus whatever timings and
            usage the stream provided; fields stay None when the matching
            event never arrived.
        """
        # Initialize everything up front: a stream without a contentBlockStop
        # or metadata event must not raise UnboundLocalError below.
        time_to_first_token = None
        time_to_last_token = None
        metadata = None
        text_parts = []

        for chunk in client_response["stream"]:
            if "contentBlockDelta" in chunk:
                text_parts.append(chunk["contentBlockDelta"]["delta"].get("text") or "")
                if time_to_first_token is None:
                    time_to_first_token = time.perf_counter() - start_t

            if "contentBlockStop" in chunk:
                time_to_last_token = time.perf_counter() - start_t

            if "metadata" in chunk:
                metadata = chunk["metadata"]

        response = InvocationResponse(
            id=uuid4().hex,
            response_text="".join(text_parts),
            time_to_last_token=time_to_last_token,
            time_to_first_token=time_to_first_token,
        )

        if metadata:
            usage = metadata.get("usage", {})
            response.num_tokens_input = usage.get("inputTokens")
            response.num_tokens_output = usage.get("outputTokens")

        num_tokens_output = response.num_tokens_output
        if num_tokens_output and time_to_last_token and time_to_first_token:
            generation_time = time_to_last_token - time_to_first_token
            # The first token is covered by time_to_first_token, so average
            # over the remaining ones; a single-token response yields 0.
            response.time_per_output_token = (
                generation_time / (num_tokens_output - 1)
                if num_tokens_output > 1
                else 0
            )

        return response
|
|
@@ -0,0 +1,165 @@
|
|
|
1
|
+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
import json
|
|
5
|
+
import logging
|
|
6
|
+
import os
|
|
7
|
+
import time
|
|
8
|
+
from typing import Sequence
|
|
9
|
+
from uuid import uuid4
|
|
10
|
+
|
|
11
|
+
import litellm
|
|
12
|
+
from litellm import CustomStreamWrapper, completion
|
|
13
|
+
from litellm.types.utils import ModelResponse
|
|
14
|
+
from litellm.utils import get_llm_provider # type: ignore
|
|
15
|
+
|
|
16
|
+
from llmeter.endpoints import Endpoint, InvocationResponse
|
|
17
|
+
|
|
18
|
+
logger = logging.getLogger(__name__)

# Module-level LiteLLM settings, applied once when this module is imported.
# Presumably these reduce LiteLLM's own logging/console output — confirm
# against the LiteLLM documentation.
litellm.json_logs = True  # type: ignore
litellm.turn_off_message_logging = True
litellm.suppress_debug_info = True

# NOTE(review): these environment variables are set after `litellm` has
# already been imported above; verify that LiteLLM reads them lazily rather
# than at import time.
os.environ["LITELLM_LOG"] = "CRITICAL"
os.environ["LITELLM_DONT_SHOW_FEEDBACK_BOX"] = "true"
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class LiteLLMBase(Endpoint):
    """Base class for endpoints that call models through LiteLLM."""

    def __init__(
        self,
        litellm_model: str,
        model_id: str | None = None,
    ):
        """
        Initialize the endpoint from a LiteLLM model string.

        The model name and provider are resolved with `get_llm_provider`; the
        resolved model name is used as the endpoint name.

        Args:
            litellm_model (str): Model identifier passed to LiteLLM's `completion()`.
            model_id (str | None, optional): Model ID to record. Defaults to the
                model name resolved from `litellm_model`.
        """
        self.litellm_model = litellm_model
        model_id_inferred, provider, _, _ = get_llm_provider(litellm_model)

        logger.info("Using model %s from provider %s", model_id_inferred, provider)
        super().__init__(
            model_id=model_id or model_id_inferred,
            provider=provider,
            endpoint_name=model_id_inferred,
        )

    def _parse_payload(self, payload):
        """Return the payload's `messages` entry serialized as a JSON string."""
        return json.dumps(payload.get("messages"))

    @staticmethod
    def create_payload(
        user_message: str | Sequence[str],
        max_tokens: int = 256,
        system_message: str | None = None,
        **kwargs,
    ):
        """
        Create a payload for the LiteLLM `completion()` request.

        Args:
            user_message (str | Sequence[str]): The user's message or a sequence of messages.
            max_tokens (int, optional): The maximum number of tokens to generate. Defaults to 256.
            system_message (str | None, optional): System prompt. When given it
                is placed before all other messages.
            **kwargs: Additional keyword arguments to include in the payload.

        Returns:
            dict: The formatted payload for the LiteLLM `completion()` request.
        """
        if isinstance(user_message, str):
            user_message = [user_message]
        payload = {
            "messages": [{"role": "user", "content": k} for k in user_message],
            "max_tokens": max_tokens,
        }
        payload.update(kwargs)
        if system_message:
            # The system prompt must lead the conversation. Build a new list
            # so a `messages` list supplied through kwargs is not mutated.
            payload["messages"] = [
                {"role": "system", "content": system_message},
                *payload["messages"],
            ]
        return payload
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
class LiteLLM(LiteLLMBase):
    """Endpoint that performs non-streaming LiteLLM `completion()` calls."""

    def invoke(self, payload, **kwargs):
        """
        Call `completion()` with the payload and parse the result.

        Args:
            payload (dict): Keyword arguments for `completion()`.
            **kwargs: Additional keyword arguments for `completion()`.

        Returns:
            InvocationResponse: The parsed response, or an error response
            (carrying the input prompt) if anything in the call raised.
        """
        try:
            response = completion(model=self.litellm_model, **payload, **kwargs)
            if not isinstance(response, ModelResponse):
                # An explicit error instead of `assert`: the check survives
                # `python -O`, and the error response gets a non-empty message.
                raise TypeError(
                    f"Expected a ModelResponse from LiteLLM, got {type(response).__name__}"
                )
            response = self._parse_converse_response(response)
            response.input_prompt = self._parse_payload(payload)
            return response

        except Exception as e:
            logger.exception(e)
            return InvocationResponse.error_output(
                id=uuid4().hex, error=str(e), input_prompt=self._parse_payload(payload)
            )

    def _parse_converse_response(
        self, client_response: ModelResponse
    ) -> InvocationResponse:
        """
        Convert a LiteLLM `ModelResponse` into an InvocationResponse.

        The text comes from the first choice; token counts are filled in only
        when the response exposes usage information.
        """
        response = InvocationResponse(
            id=client_response.id,
            response_text=client_response.choices[0].message.content,  # type: ignore
        )
        try:
            usage = client_response.usage  # type: ignore
            response.num_tokens_input = usage.prompt_tokens
            response.num_tokens_output = usage.completion_tokens
        except AttributeError:
            # Usage is optional: leave the token counts as None when absent.
            pass

        return response
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
class LiteLLMStreaming(LiteLLMBase):
    """Endpoint that performs streaming LiteLLM `completion()` calls."""

    def invoke(self, payload, **kwargs):
        """
        Call `completion()` in streaming mode and consume the stream.

        Args:
            payload (dict): Keyword arguments for `completion()`; its entries
                take precedence over `kwargs` with the same key.
            **kwargs: Additional keyword arguments for `completion()`.

        Returns:
            InvocationResponse: The parsed response with timings, or an error
            response (carrying the input prompt) if the call failed.
        """
        # Merge into a single dict so a key present in both places (e.g.
        # "stream") no longer makes the call fail with a duplicate keyword.
        request = {**kwargs, **payload}
        # This endpoint only understands streamed responses.
        request["stream"] = True
        request.setdefault("stream_options", {"include_usage": True})

        try:
            start_t = time.perf_counter()
            response = completion(model=self.litellm_model, **request)
        except Exception as e:
            logger.exception(e)
            return InvocationResponse.error_output(
                id=uuid4().hex, error=str(e), input_prompt=self._parse_payload(payload)
            )

        if not isinstance(response, CustomStreamWrapper):
            # Explicit check instead of `assert`, which is stripped under -O.
            message = (
                f"Expected a CustomStreamWrapper from LiteLLM, got {type(response).__name__}"
            )
            logger.error(message)
            return InvocationResponse.error_output(
                id=uuid4().hex, error=message, input_prompt=self._parse_payload(payload)
            )

        response = self._parse_stream(response, start_t)
        response.input_prompt = self._parse_payload(payload)
        return response

    def _parse_stream(
        self, client_response: CustomStreamWrapper, start_t: float
    ) -> InvocationResponse:
        """
        Consume the chunk stream and collect text, timings and token usage.

        Args:
            client_response (CustomStreamWrapper): The stream returned by `completion()`.
            start_t (float): `time.perf_counter()` value taken before the request.

        Returns:
            InvocationResponse: Concatenated text, time to first chunk, time to
            the end of the stream and, when reported, token counts.
        """
        usage = None
        response_id = None
        time_to_first_token = None
        text_parts = []

        for chunk in client_response:
            # A chunk may carry no choices (e.g. one that only reports usage).
            choices = getattr(chunk, "choices", None)
            if choices:
                text_parts.append(choices[0].delta.content or "")  # type: ignore
            if time_to_first_token is None:
                time_to_first_token = time.perf_counter() - start_t
            response_id = chunk.id
            # Keep the last usage actually reported; chunks without one must
            # not reset it to None.
            usage = getattr(chunk, "usage", None) or usage

        time_to_last_token = time.perf_counter() - start_t

        response = InvocationResponse(
            # An empty stream yields no id; fall back to a generated one.
            id=response_id or uuid4().hex,
            response_text="".join(text_parts),
            num_tokens_input=usage and usage.prompt_tokens,
            num_tokens_output=usage and usage.completion_tokens,
            time_to_first_token=time_to_first_token,
            time_to_last_token=time_to_last_token,
        )

        num_tokens_output = response.num_tokens_output
        if num_tokens_output and time_to_last_token and time_to_first_token:
            generation_time = time_to_last_token - time_to_first_token
            # Average over the tokens after the first; a single-token
            # response yields 0.
            response.time_per_output_token = (
                generation_time / (num_tokens_output - 1)
                if num_tokens_output > 1
                else 0
            )
        return response
|
|
@@ -0,0 +1,78 @@
|
|
|
1
|
+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
import logging
|
|
5
|
+
from typing import Dict, Sequence
|
|
6
|
+
from uuid import uuid4
|
|
7
|
+
|
|
8
|
+
import jmespath
|
|
9
|
+
from openai import APIConnectionError, OpenAI
|
|
10
|
+
from openai.types.chat import ChatCompletion
|
|
11
|
+
|
|
12
|
+
from .base import Endpoint, InvocationResponse
|
|
13
|
+
|
|
14
|
+
logger = logging.getLogger(__name__)
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class OpenAIEndpoint(Endpoint):
    """Base class for endpoints backed by the OpenAI client."""

    def __init__(
        self,
        model_id: str,
        endpoint_name: str = "openai",
        api_key: str | None = None,
        *args,
        **kwargs,
    ):
        """
        Initialize the endpoint and its OpenAI client.

        Args:
            model_id (str): The model to call.
            endpoint_name (str, optional): Name recorded for this endpoint.
            api_key (str | None, optional): API key passed to the OpenAI client.
            *args: At most one extra positional argument, the provider name
                (kept for backward compatibility).
            **kwargs: ``provider`` (optional, defaults to "openai") is recorded
                on the endpoint; all remaining keyword arguments are passed to
                the ``OpenAI`` client constructor.

        Raises:
            TypeError: If more than one provider is supplied.
        """
        if len(args) > 1 or (args and "provider" in kwargs):
            raise TypeError("OpenAIEndpoint accepts at most one provider argument")
        # The base class requires a provider. It used to be forwarded only via
        # *args/**kwargs, and the same kwargs were also handed to the OpenAI
        # client, so the two consumers could never both be satisfied.
        provider = args[0] if args else kwargs.pop("provider", "openai")
        super().__init__(
            endpoint_name=endpoint_name,
            model_id=model_id,
            provider=provider,
        )

        self._client = OpenAI(api_key=api_key, **kwargs)

    def _parse_payload(self, payload):
        """
        Extract the prompt text from a chat payload.

        Args:
            payload (dict): The payload containing ``messages``.

        Returns:
            str: The text of every message joined by newlines. String contents
            are used as-is; for list contents the ``text`` entry of each part
            is used. Empty when the payload has no messages.
        """
        jpath = "[:].content"
        messages = payload.get("messages")
        contents = jmespath.search(jpath, messages) or []

        texts = []
        for content in contents:
            if isinstance(content, str):
                # Iterating a string would split it into single characters.
                texts.append(content)
            else:
                texts.extend(
                    part["text"]
                    for part in content
                    if isinstance(part, dict) and isinstance(part.get("text"), str)
                )
        return "\n".join(texts)

    @staticmethod
    def create_payload(
        user_message: str | Sequence[str], max_tokens: int = 256, **kwargs
    ):
        """
        Create a payload for a chat completion request.

        Args:
            user_message (str | Sequence[str]): The user's message or a sequence of messages.
            max_tokens (int, optional): The maximum number of tokens to generate. Defaults to 256.
            **kwargs: Additional keyword arguments to include in the payload.

        Returns:
            dict: The payload with one user message per input string.
        """
        if isinstance(user_message, str):
            user_message = [user_message]
        payload = {
            "messages": [{"role": "user", "content": k} for k in user_message],
            "max_tokens": max_tokens,
        }
        payload.update(kwargs)
        return payload
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
class OpenAICompletionEndpoint(OpenAIEndpoint):
    """Endpoint that calls `chat.completions.create` on the OpenAI client."""

    def invoke(self, payload: Dict, **kwargs) -> InvocationResponse:
        """
        Send the payload as a chat completion request.

        Args:
            payload (Dict): Request arguments; its entries take precedence over
                `kwargs` with the same key.
            **kwargs: Additional request arguments.

        Returns:
            InvocationResponse: The parsed response including the input prompt,
            or an error response if the client call raised.
        """
        request = {**kwargs, **payload}
        request["model"] = self.model_id

        try:
            completion: ChatCompletion = self._client.chat.completions.create(
                **request
            )
        except (APIConnectionError, Exception) as e:
            logger.error(e)
            return InvocationResponse.error_output(id=uuid4().hex, error=str(e))

        result = self._parse_converse_response(completion)
        result.input_prompt = self._parse_payload(request)
        return result

    def _parse_converse_response(self, client_response: ChatCompletion):
        """
        Convert a `ChatCompletion` into an InvocationResponse.

        The text comes from the first choice; token counts are taken from
        `usage` when it is present.
        """
        usage = client_response.usage
        fields = {
            "id": client_response.id,
            "response_text": client_response.choices[0].message.content,
            "num_tokens_input": usage and usage.prompt_tokens,
            "num_tokens_output": usage and usage.completion_tokens,
        }
        return InvocationResponse(**fields)
|