llmstudio 0.2.2__tar.gz → 0.2.3__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- llmstudio-0.2.3/MANIFEST.in +1 -0
- {llmstudio-0.2.2 → llmstudio-0.2.3}/PKG-INFO +1 -1
- {llmstudio-0.2.2 → llmstudio-0.2.3}/llmstudio/__init__.py +1 -1
- llmstudio-0.2.3/llmstudio/engine/providers/__init__.py +18 -0
- llmstudio-0.2.3/llmstudio/engine/providers/base_provider.py +70 -0
- llmstudio-0.2.3/llmstudio/engine/providers/bedrock.py +313 -0
- llmstudio-0.2.3/llmstudio/engine/providers/openai.py +233 -0
- llmstudio-0.2.3/llmstudio/engine/providers/vertexai.py +312 -0
- {llmstudio-0.2.2 → llmstudio-0.2.3}/llmstudio/models/models.py +113 -2
- llmstudio-0.2.3/llmstudio/ui/build/asset-manifest.json +19 -0
- llmstudio-0.2.3/llmstudio/ui/build/fonts/Satoshi-Black.woff +0 -0
- llmstudio-0.2.3/llmstudio/ui/build/fonts/Satoshi-Bold.woff +0 -0
- llmstudio-0.2.3/llmstudio/ui/build/fonts/Satoshi-Light.woff +0 -0
- llmstudio-0.2.3/llmstudio/ui/build/fonts/Satoshi-Medium.woff +0 -0
- llmstudio-0.2.3/llmstudio/ui/build/fonts/Satoshi-Regular.woff +0 -0
- llmstudio-0.2.3/llmstudio/ui/build/fonts/VioletSans-Regular.woff +0 -0
- llmstudio-0.2.3/llmstudio/ui/build/images/claudio.jpg +0 -0
- llmstudio-0.2.3/llmstudio/ui/build/images/icon.png +0 -0
- llmstudio-0.2.3/llmstudio/ui/build/index.html +1 -0
- llmstudio-0.2.3/llmstudio/ui/build/manifest.json +25 -0
- llmstudio-0.2.3/llmstudio/ui/build/robots.txt +3 -0
- llmstudio-0.2.3/llmstudio/ui/build/static/css/main.0342ffa4.css +4 -0
- llmstudio-0.2.3/llmstudio/ui/build/static/css/main.0342ffa4.css.map +1 -0
- llmstudio-0.2.3/llmstudio/ui/build/static/js/main.7337aa4e.js +3 -0
- llmstudio-0.2.3/llmstudio/ui/build/static/js/main.7337aa4e.js.LICENSE.txt +83 -0
- llmstudio-0.2.3/llmstudio/ui/build/static/js/main.7337aa4e.js.map +1 -0
- llmstudio-0.2.3/llmstudio/ui/build/static/media/Satoshi-Black.4261d202e1e9410db1bf.woff +0 -0
- llmstudio-0.2.3/llmstudio/ui/build/static/media/Satoshi-Bold.a875ff682ee232938607.woff +0 -0
- llmstudio-0.2.3/llmstudio/ui/build/static/media/Satoshi-Light.67e7fa77f107df3491b6.woff +0 -0
- llmstudio-0.2.3/llmstudio/ui/build/static/media/Satoshi-Medium.2419b46c96ed15331ba2.woff +0 -0
- llmstudio-0.2.3/llmstudio/ui/build/static/media/Satoshi-Regular.ca3da5fd2b609836ef69.woff +0 -0
- llmstudio-0.2.3/llmstudio/ui/build/static/media/VioletSans-Regular.425614770e8617faebdd.woff +0 -0
- llmstudio-0.2.3/llmstudio/ui/build/svg/ai.svg +10 -0
- llmstudio-0.2.3/llmstudio/ui/build/svg/arrow.svg +3 -0
- llmstudio-0.2.3/llmstudio/ui/build/svg/home.svg +12 -0
- llmstudio-0.2.3/llmstudio/ui/build/svg/load.svg +4 -0
- llmstudio-0.2.3/llmstudio/ui/build/svg/magic.svg +10 -0
- llmstudio-0.2.3/llmstudio/ui/build/svg/play.svg +11 -0
- llmstudio-0.2.3/llmstudio/ui/build/svg/playground.svg +8 -0
- llmstudio-0.2.3/llmstudio/ui/build/svg/plus.svg +9 -0
- llmstudio-0.2.3/llmstudio/ui/build/svg/prompt.svg +7 -0
- llmstudio-0.2.3/llmstudio/ui/build/svg/settings.svg +15 -0
- llmstudio-0.2.3/llmstudio/ui/build/svg/sparkles.svg +3 -0
- llmstudio-0.2.3/llmstudio/utils/__init__.py +0 -0
- {llmstudio-0.2.2 → llmstudio-0.2.3}/llmstudio.egg-info/PKG-INFO +1 -1
- llmstudio-0.2.3/llmstudio.egg-info/SOURCES.txt +68 -0
- {llmstudio-0.2.2 → llmstudio-0.2.3}/setup.py +6 -18
- llmstudio-0.2.2/llmstudio.egg-info/SOURCES.txt +0 -27
- {llmstudio-0.2.2 → llmstudio-0.2.3}/LICENSE +0 -0
- {llmstudio-0.2.2 → llmstudio-0.2.3}/README.md +0 -0
- {llmstudio-0.2.2 → llmstudio-0.2.3}/llmstudio/cli.py +0 -0
- {llmstudio-0.2.2 → llmstudio-0.2.3}/llmstudio/client.py +0 -0
- {llmstudio-0.2.2 → llmstudio-0.2.3}/llmstudio/engine/__init__.py +0 -0
- {llmstudio-0.2.2 → llmstudio-0.2.3}/llmstudio/engine/config.py +0 -0
- {llmstudio-0.2.2 → llmstudio-0.2.3}/llmstudio/engine/constants.py +0 -0
- {llmstudio-0.2.2 → llmstudio-0.2.3}/llmstudio/engine/utils.py +0 -0
- {llmstudio-0.2.2 → llmstudio-0.2.3}/llmstudio/models/__init__.py +0 -0
- {llmstudio-0.2.2 → llmstudio-0.2.3}/llmstudio/models/bedrock.py +0 -0
- {llmstudio-0.2.2 → llmstudio-0.2.3}/llmstudio/models/openai.py +0 -0
- {llmstudio-0.2.2 → llmstudio-0.2.3}/llmstudio/models/vertexai.py +0 -0
- {llmstudio-0.2.2 → llmstudio-0.2.3}/llmstudio/ui/__init__.py +0 -0
- {llmstudio-0.2.2 → llmstudio-0.2.3}/llmstudio/utils/rest_utils.py +0 -0
- {llmstudio-0.2.2 → llmstudio-0.2.3}/llmstudio/validators/__init__.py +0 -0
- {llmstudio-0.2.2 → llmstudio-0.2.3}/llmstudio/validators/bedrock.py +0 -0
- {llmstudio-0.2.2 → llmstudio-0.2.3}/llmstudio/validators/openai.py +0 -0
- {llmstudio-0.2.2 → llmstudio-0.2.3}/llmstudio/validators/vertexai.py +0 -0
- {llmstudio-0.2.2 → llmstudio-0.2.3}/llmstudio.egg-info/dependency_links.txt +0 -0
- {llmstudio-0.2.2 → llmstudio-0.2.3}/llmstudio.egg-info/entry_points.txt +0 -0
- {llmstudio-0.2.2 → llmstudio-0.2.3}/llmstudio.egg-info/requires.txt +0 -0
- {llmstudio-0.2.2 → llmstudio-0.2.3}/llmstudio.egg-info/top_level.txt +0 -0
- {llmstudio-0.2.2 → llmstudio-0.2.3}/setup.cfg +0 -0
|
@@ -0,0 +1 @@
|
|
|
1
|
+
recursive-include llmstudio/ui/build *
|
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
from llmstudio.engine.config import Provider
|
|
2
|
+
from llmstudio.engine.providers.base_provider import BaseProvider
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
def get_provider(provider: Provider) -> BaseProvider:
    """
    Look up the provider implementation registered for ``provider``.

    The concrete provider modules are imported inside the function body, so
    they are only loaded when a lookup is actually performed.

    Args:
        provider (Provider): Enum member identifying the backend.

    Returns:
        The provider class (not an instance) mapped to ``provider``.

    Raises:
        ValueError: If no class is registered for ``provider``.
    """
    from llmstudio.engine.providers.bedrock import BedrockProvider
    from llmstudio.engine.providers.openai import OpenAIProvider
    from llmstudio.engine.providers.vertexai import VertexAIProvider

    registry = {
        Provider.OPENAI: OpenAIProvider,
        Provider.VERTEXAI: VertexAIProvider,
        Provider.BEDROCK: BedrockProvider,
    }

    provider_class = registry.get(provider)
    if not provider_class:
        raise ValueError(f"Provider {provider} not found")
    return provider_class
|
|
@@ -0,0 +1,70 @@
|
|
|
1
|
+
from abc import ABC
|
|
2
|
+
from concurrent.futures import ThreadPoolExecutor
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
class BaseProvider(ABC):
    """
    Abstract base class for LLMStudio engine providers.

    This class defines the core interface for providers that integrate
    with LLMStudio's engine. It is intended to be subclassed by specific
    provider implementations, which override `chat` and `test`.

    Attributes:
        executor (ThreadPoolExecutor): Pool (10 workers) that subclasses can
            use to run blocking SDK calls off the event loop.
    """

    def __init__(self):
        super().__init__()
        self.executor = ThreadPoolExecutor(max_workers=10)

    async def chat(self, data) -> dict:
        """
        Asynchronously handle a chat request.

        Parameters:
            data: The data payload for the chat operation.

        Returns:
            dict: A dictionary containing the chat response.

        Raises:
            NotImplementedError: This method is intended to be overridden by subclasses.
        """
        raise NotImplementedError

    async def test(self, data) -> dict:
        """
        Asynchronously handle a test request.

        Parameters:
            data: The data payload for the test operation.

        Returns:
            dict: A dictionary containing the test response.

        Raises:
            NotImplementedError: This method is intended to be overridden by subclasses.
        """
        raise NotImplementedError

    def validate_model_field(self, data, model_list):
        """
        Validate the 'model_name' field in the request data.

        Parameters:
            data: Request object exposing a `model_name` attribute.
            model_list: Collection of valid model names.

        Raises:
            HTTPException: (status 422) if 'model_name' is empty or is not in
                `model_list`.
        """
        # Imported locally, as in the original, so importing this module does
        # not itself require fastapi.
        from fastapi import HTTPException

        if not data.model_name:
            raise HTTPException(
                status_code=422,
                detail="The parameter 'model_name' is mandatory to be passed in the request body.",
            )
        if data.model_name not in model_list:
            # `data` is read by attribute everywhere else; subscripting it here
            # would raise TypeError instead of the intended 422.
            raise HTTPException(
                status_code=422,
                detail=f"The model '{data.model_name}' does not exist.",
            )
|
|
@@ -0,0 +1,313 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import json
|
|
3
|
+
import random
|
|
4
|
+
import time
|
|
5
|
+
from typing import Optional, Tuple
|
|
6
|
+
|
|
7
|
+
import boto3
|
|
8
|
+
from fastapi.responses import StreamingResponse
|
|
9
|
+
from pydantic import BaseModel, Field, validator
|
|
10
|
+
|
|
11
|
+
from llmstudio.engine.config import BedrockConfig
|
|
12
|
+
from llmstudio.engine.constants import BEDROCK_MODELS, CLAUDE_MODELS, END_TOKEN, TITAN_MODELS
|
|
13
|
+
from llmstudio.engine.providers.base_provider import BaseProvider
|
|
14
|
+
from llmstudio.engine.utils import validate_provider_config
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class ClaudeParameters(BaseModel):
    """
    Model for validating and storing parameters specific to Claude model.

    Attributes:
        temperature (Optional[float]): Controls randomness in the model's output.
            Default 1, allowed range [0, 1].
        max_tokens (Optional[int]): The maximum number of tokens in the output.
            Default 300, allowed range [1, 2048].
        top_p (Optional[float]): Influences the diversity of output by controlling token sampling.
            Default 0.999, allowed range [0, 1].
        top_k (Optional[int]): Sets the number of the most likely next tokens to filter for.
            Default 250, allowed range [1, 500].
    """

    temperature: Optional[float] = Field(default=1, ge=0, le=1)
    max_tokens: Optional[int] = Field(default=300, ge=1, le=2048)
    top_p: Optional[float] = Field(default=0.999, ge=0, le=1)
    top_k: Optional[int] = Field(default=250, ge=1, le=500)
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class TitanParameters(BaseModel):
    """
    Validated generation settings for Titan models.

    Attributes:
        temperature (Optional[float]): Randomness of the output; default 0,
            allowed range [0, 1].
        max_tokens (Optional[int]): Upper bound on generated tokens; default
            512, allowed range [1, 4096].
        top_p (Optional[float]): Token sampling cutoff; default 0.9, allowed
            range [0.1, 1].
    """

    temperature: Optional[float] = Field(default=0, ge=0, le=1)
    max_tokens: Optional[int] = Field(default=512, ge=1, le=4096)
    top_p: Optional[float] = Field(default=0.9, ge=0.1, le=1)
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
class BedrockRequest(BaseModel):
    """
    Represents a request to the Bedrock API.

    Attributes:
        api_key (Optional[str]): The API key for authenticating with the Bedrock API.
        api_secret (Optional[str]): The API secret for authenticating with the Bedrock API.
        api_region (Optional[str]): The region where the Bedrock API is hosted.
        model_name (str): The name of the model to be used for the request.
        chat_input (str): The input string for the chat.
        parameters (Optional[BaseModel]): Additional parameters for the model. Converted by the
            validator below into `TitanParameters` or `ClaudeParameters`.
        is_stream (Optional[bool]): Flag to indicate if the request is for streaming. Defaults to False.
    """

    api_key: Optional[str]
    api_secret: Optional[str]
    api_region: Optional[str]
    model_name: str
    chat_input: str
    parameters: Optional[BaseModel]
    is_stream: Optional[bool] = False

    @validator("parameters", pre=True, always=True)
    def validate_parameters_based_on_model_name(cls, parameters, values):
        """
        Validate and convert parameters based on the model_name.

        Args:
            parameters (Optional[Dict[str, Any] | BaseModel]): Raw parameters; may be
                None when the field was omitted from the request.
            values (Dict[str, Any]): Contains previously validated fields.

        Returns:
            BaseModel: An instance of `TitanParameters` or `ClaudeParameters` based on `model_name`.

        Raises:
            ValueError if model_name is invalid.
        """
        # `always=True` runs this validator even when `parameters` was not
        # supplied, in which case the raw value is None; fall back to the
        # parameter model's defaults instead of failing on `**None`.
        if parameters is None:
            parameters = {}
        elif isinstance(parameters, BaseModel):
            # Accept an already-built parameter model by re-validating its fields.
            parameters = parameters.dict()

        model_name = values.get("model_name")
        if model_name in TITAN_MODELS:
            return TitanParameters(**parameters)
        if model_name in CLAUDE_MODELS:
            return ClaudeParameters(**parameters)

        raise ValueError(f"Invalid model_name: {model_name}")
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
class BedrockTest(BaseModel):
    """
    A Pydantic model for validating Bedrock API test requests.

    Attributes:
        api_key (Optional[str]): The API key provided by the user for authentication with Bedrock's API.
        api_secret (Optional[str]): The API secret key provided by the user for authentication.
        api_region (Optional[str]): The API region for Bedrock API requests.
        model_name (str): The name of the model intended for use with the Bedrock API.

    No validators are defined on this model; only the field types above are checked.
    """

    api_key: Optional[str]
    api_secret: Optional[str]
    api_region: Optional[str]
    model_name: str
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
class BedrockProvider(BaseProvider):
    """
    BedrockProvider class to interact with the Bedrock API.

    Attributes:
        bedrock_config: Result of `validate_provider_config`; read by key for
            "api_key", "api_secret" and "api_region".
    """

    def __init__(self, config: BedrockConfig, api_key: dict):
        """
        Initialize the BedrockProvider class.

        Args:
            config (BedrockConfig): Configuration for the Bedrock API.
            api_key (dict): API key required for the Bedrock API.

        Raises:
            ValidationError: If the provided config and API key are invalid.
        """
        super().__init__()
        self.bedrock_config = validate_provider_config(config, api_key)

    def _create_client(self):
        """Build a boto3 Bedrock client from the validated provider config."""
        session = boto3.Session(
            aws_access_key_id=self.bedrock_config["api_key"],
            aws_secret_access_key=self.bedrock_config["api_secret"],
        )
        # NOTE(review): newer boto3 releases expose invoke_model on the
        # "bedrock-runtime" client - confirm against the pinned boto3 version.
        return session.client(
            service_name="bedrock", region_name=self.bedrock_config["api_region"]
        )

    async def chat(self, data: BedrockRequest) -> dict:
        """
        Endpoint to process chat input via Bedrock API and generate a model's response.

        Args:
            data (BedrockRequest): Request payload; it is unpacked into a
                `BedrockRequest` for validation.

        Returns:
            Union[StreamingResponse, dict]: Streaming response if is_stream is True, otherwise a dict with chat and token data.
        """
        request = BedrockRequest(**data)
        self.validate_model_field(request, BEDROCK_MODELS)
        loop = asyncio.get_event_loop()
        bedrock = self._create_client()

        body, response_keys = generate_body_and_response(request)
        invoke_kwargs = {
            "body": json.dumps(body),
            "modelId": request.model_name,
            "accept": "application/json",
            "contentType": "application/json",
        }

        if request.is_stream:
            # The boto3 call blocks, so it runs in the loop's default executor.
            stream = await loop.run_in_executor(
                None,
                lambda: bedrock.invoke_model_with_response_stream(**invoke_kwargs).get("body"),
            )
            return StreamingResponse(generate_stream_response(stream, response_keys))

        payload = await loop.run_in_executor(
            None,
            lambda: json.loads(bedrock.invoke_model(**invoke_kwargs).get("body").read()),
        )

        result = payload["results"][0] if response_keys["use_results"] else payload

        def token_count(key):
            # Models that report no token counts declare the key as None.
            if key is None:
                return 0
            # Titan reports the input token count on the top-level payload,
            # not inside results[0], so fall back to the payload.
            return result.get(key, payload.get(key, 0))

        input_tokens = token_count(response_keys["input_tokens_key"])
        output_tokens = token_count(response_keys["output_tokens_key"])

        return {
            "id": random.randint(0, 1000),
            "chatInput": request.chat_input,
            "chatOutput": result[response_keys["output_key"]],
            "inputTokens": input_tokens,
            "outputTokens": output_tokens,
            "totalTokens": input_tokens + output_tokens,
            "cost": 0,  # TODO
            "timestamp": time.time(),
            "modelName": request.model_name,
            "parameters": request.parameters.dict(),
        }

    async def test(self, data: BedrockTest) -> bool:
        """
        Test the validity of the Bedrock API credentials and model name.

        Args:
            data (BedrockTest): Payload containing the Bedrock API credentials
                and model name to test; it is unpacked into a `BedrockTest`.

        Returns:
            bool: `True` if the foundation-model listing succeeds and contains
                `model_name`, otherwise `False`.
        """
        data = BedrockTest(**data)
        try:
            bedrock = self._create_client()
            response = bedrock.list_foundation_models()
            return any(
                summary["modelId"] == data.model_name
                for summary in response["modelSummaries"]
            )
        except Exception:
            # Deliberate best effort: any failure (bad credentials, region,
            # network) is reported as an invalid configuration.
            return False
|
|
231
|
+
|
|
232
|
+
|
|
233
|
+
def generate_body_and_response(data: BedrockRequest) -> Tuple[dict, dict]:
    """
    Generate request body and response keys based on model name.

    Args:
        data (BedrockRequest): Validated API request data.

    Returns:
        Tuple[dict, dict]: Tuple of request body and response keys. The
            response keys name the output field, the input/output token
            fields (None when the model has none) and whether the output is
            nested under "results".

    Raises:
        ValueError if model name is invalid.
    """
    params = data.parameters

    if data.model_name in TITAN_MODELS:
        body = {
            "inputText": data.chat_input,
            "textGenerationConfig": {
                "maxTokenCount": params.max_tokens,
                "temperature": params.temperature,
                "topP": params.top_p,
            },
        }
        response_keys = {
            "output_key": "outputText",
            "input_tokens_key": "inputTextTokenCount",
            "output_tokens_key": "tokenCount",
            "use_results": True,
        }
        return body, response_keys

    if data.model_name in CLAUDE_MODELS:
        body = {
            "prompt": data.chat_input,
            "max_tokens_to_sample": params.max_tokens,
            "temperature": params.temperature,
            "top_k": params.top_k,
            "top_p": params.top_p,
        }
        response_keys = {
            "output_key": "completion",
            "input_tokens_key": None,
            "output_tokens_key": None,
            "use_results": False,
        }
        return body, response_keys

    raise ValueError(f"Invalid model_name: {data.model_name}")
|
|
275
|
+
|
|
276
|
+
|
|
277
|
+
def get_cost(input_tokens: int, output_tokens: int) -> Optional[float]:
    """
    Calculate the cost based on input and output tokens.

    Placeholder: no cost calculation is implemented here yet, so this
    always returns None regardless of the arguments.

    Args:
        input_tokens (int): Number of tokens in the input.
        output_tokens (int): Number of tokens in the output.

    Returns:
        Optional[float]: Currently always None.
    """
    return None
|
|
289
|
+
|
|
290
|
+
|
|
291
|
+
def generate_stream_response(response, response_keys):
    """
    Generate streaming response based on response events and keys.

    Args:
        response (Any): Iterable of events from the Bedrock streaming call.
        response_keys (Dict[str, Any]): Keys to extract relevant data from the response;
            only "output_key" is read here.

    Yields:
        str: The text of each chunk, then a final
            "<END_TOKEN>,<input_tokens>,<output_tokens>,<cost>" summary line.
    """
    output_key = response_keys["output_key"]
    for event in response:
        chunk = event.get("chunk")
        if not chunk:
            # Events without a chunk payload carry no text.
            continue
        yield json.loads(chunk.get("bytes").decode())[output_key]

    # Token counts and cost are not computed for streamed responses yet,
    # so the trailing summary reports zeros.
    input_tokens = 0
    output_tokens = 0
    cost = 0
    yield f"{END_TOKEN},{input_tokens},{output_tokens},{cost}"
|
|
@@ -0,0 +1,233 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import random
|
|
3
|
+
import time
|
|
4
|
+
from typing import Optional
|
|
5
|
+
|
|
6
|
+
import openai
|
|
7
|
+
import tiktoken
|
|
8
|
+
from fastapi.responses import StreamingResponse
|
|
9
|
+
from pydantic import BaseModel, Field
|
|
10
|
+
|
|
11
|
+
from llmstudio.engine.config import OpenAIConfig
|
|
12
|
+
from llmstudio.engine.constants import END_TOKEN, OPENAI_PRICING_DICT
|
|
13
|
+
from llmstudio.engine.providers.base_provider import BaseProvider
|
|
14
|
+
from llmstudio.engine.utils import validate_provider_config
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class OpenAIParameters(BaseModel):
    """
    Validated generation settings sent with OpenAI API requests.

    Attributes:
        temperature (Optional[float]): Randomness of the output; default 1,
            allowed range [0, 2].
        max_tokens (Optional[int]): Upper bound on generated tokens; default
            256, allowed range [1, 2048].
        top_p (Optional[float]): Token sampling cutoff; default 1, allowed
            range [0, 1].
        frequency_penalty (Optional[float]): Penalty based on token
            frequency; default 0, allowed range [0, 1].
        presence_penalty (Optional[float]): Penalty based on token presence;
            default 0, allowed range [0, 1].
    """

    temperature: Optional[float] = Field(1, ge=0, le=2)
    max_tokens: Optional[int] = Field(256, ge=1, le=2048)
    top_p: Optional[float] = Field(1, ge=0, le=1)
    frequency_penalty: Optional[float] = Field(0, ge=0, le=1)
    presence_penalty: Optional[float] = Field(0, ge=0, le=1)
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class OpenAIRequest(BaseModel):
    """
    Request payload for a chat call against the OpenAI API.

    Attributes:
        api_key (Optional[str]): API key used to authenticate the request.
        model_name (str): Name of the language model to query.
        chat_input (str): Input text sent to the model.
        parameters (Optional[OpenAIParameters]): Generation settings; when
            omitted, an `OpenAIParameters` built from its defaults is used.
        is_stream (Optional[bool]): Whether to request a streaming response;
            defaults to False.
    """

    api_key: Optional[str]
    model_name: str
    chat_input: str
    parameters: Optional[OpenAIParameters] = OpenAIParameters()
    is_stream: Optional[bool] = False
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
class OpenAITest(BaseModel):
    """
    A Pydantic model for validating OpenAI API test requests.

    Attributes:
        api_key (Optional[str]): The API key provided by the user for authentication with the OpenAI API.
        model_name (str): The name of the model to be used for generating text.

    No validators are defined on this model; only the field types above are checked.
    """

    api_key: Optional[str]
    model_name: str
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
class OpenAIProvider(BaseProvider):
    """
    Provider that routes chat and credential-test requests to the OpenAI API.

    Attributes:
        openai_config (OpenAIConfig): Configuration settings for OpenAI API;
            its `api_key` is applied to the `openai` module before each call.
    """

    def __init__(self, config: OpenAIConfig, api_key: str):
        """
        Set up the provider from a ready-made or raw configuration.

        Args:
            config (OpenAIConfig): Used as-is when it already is an
                `OpenAIConfig`; otherwise it is passed through
                `validate_provider_config` together with `api_key` and the
                result is used to build an `OpenAIConfig`.
            api_key (str): API key for authentication.
        """
        super().__init__()
        if not isinstance(config, OpenAIConfig):
            config = OpenAIConfig(**validate_provider_config(config, api_key))
        self.openai_config = config

    async def chat(self, data: OpenAIRequest) -> dict:
        """
        Run a single-message chat completion against the OpenAI API.

        Args:
            data (OpenAIRequest): Request payload; it is unpacked into an
                `OpenAIRequest` for validation.

        Returns:
            A `StreamingResponse` when `is_stream` is set; otherwise a dict
            with the chat input and output, token counts, cost, timestamp,
            model name and the parameters used.

        Raises:
            HTTPException: Raised by `validate_model_field` when the model
                name is missing or not a key of `OPENAI_PRICING_DICT`.
        """
        request = OpenAIRequest(**data)

        self.validate_model_field(request, OPENAI_PRICING_DICT.keys())
        openai.api_key = self.openai_config.api_key

        def create_completion():
            # Blocking SDK call; executed on the provider's thread pool.
            settings = request.parameters
            return openai.ChatCompletion.create(
                model=request.model_name,
                messages=[
                    {
                        "role": "user",
                        "content": request.chat_input,
                    }
                ],
                temperature=settings.temperature,
                max_tokens=settings.max_tokens,
                top_p=settings.top_p,
                frequency_penalty=settings.frequency_penalty,
                presence_penalty=settings.presence_penalty,
                stream=request.is_stream,
            )

        event_loop = asyncio.get_event_loop()
        response = await event_loop.run_in_executor(self.executor, create_completion)

        if request.is_stream:
            return StreamingResponse(generate_stream_response(response, request))

        input_tokens = get_tokens(request.chat_input, request.model_name)
        output_tokens = get_tokens(
            response["choices"][0]["message"]["content"], request.model_name
        )

        return {
            "id": random.randint(0, 1000),
            "chatInput": request.chat_input,
            "chatOutput": response["choices"][0]["message"]["content"],
            "inputTokens": input_tokens,
            "outputTokens": output_tokens,
            "totalTokens": input_tokens + output_tokens,
            "cost": get_cost(input_tokens, output_tokens, request.model_name),
            "timestamp": time.time(),
            "modelName": request.model_name,
            "parameters": request.parameters.dict(),
        }

    async def test(self, data: OpenAITest) -> bool:
        """
        Check that the configured API key can retrieve the requested model.

        Args:
            data (OpenAITest): Payload naming the model to check; it is
                unpacked into an `OpenAITest`.

        Returns:
            bool: `True` if the model name validates and the model can be
                retrieved, otherwise `False`.
        """
        openai.api_key = self.openai_config.api_key
        request = OpenAITest(**data)
        try:
            self.validate_model_field(request, OPENAI_PRICING_DICT.keys())
            openai.Model.retrieve(request.model_name)
        except Exception:
            return False
        return True
|
|
174
|
+
|
|
175
|
+
|
|
176
|
+
def get_cost(input_tokens: int, output_tokens: int, model_name: str) -> float:
    """
    Price a request from its token counts and the model's pricing entry.

    Args:
        input_tokens (int): Number of tokens in the input.
        output_tokens (int): Number of tokens in the output.
        model_name (str): Key into `OPENAI_PRICING_DICT`.

    Returns:
        float: Input price times input tokens plus output price times output tokens.
    """
    input_cost = OPENAI_PRICING_DICT[model_name]["input_tokens"] * input_tokens
    output_cost = OPENAI_PRICING_DICT[model_name]["output_tokens"] * output_tokens
    return input_cost + output_cost
|
|
192
|
+
|
|
193
|
+
|
|
194
|
+
def get_tokens(chat_input: str, model_name: str) -> int:
    """
    Count the tokens in a string using the tokenizer tiktoken selects for the model.

    Args:
        chat_input (str): Text to be tokenized.
        model_name (str): Model identifier used to pick the encoding.

    Returns:
        int: Length of the encoded token sequence.
    """
    encoding = tiktoken.encoding_for_model(model_name)
    token_ids = encoding.encode(chat_input)
    return len(token_ids)
|
|
207
|
+
|
|
208
|
+
|
|
209
|
+
def generate_stream_response(response, data: "OpenAIRequest"):
    """
    Generate stream responses, yielding chat output and then token and cost information at stream end.

    Args:
        response: Iterable of streamed chunks from the OpenAI API; each chunk
            is read as chunk["choices"][0]["delta"].
        data (OpenAIRequest): Request object; its `chat_input` and
            `model_name` are used to compute token counts and cost.

    Yields:
        str: Each chunk of chat output, then a final
            "<END_TOKEN>,<input_tokens>,<output_tokens>,<cost>" summary line.
    """
    pieces = []
    for chunk in response:
        # Role-only and terminating chunks carry no "content" key in their
        # delta; skip them instead of raising KeyError.
        content = chunk["choices"][0]["delta"].get("content")
        if content is None:
            continue
        pieces.append(content)
        yield content

    # Emit the summary once the stream is exhausted, whatever the finish
    # reason was, so the consumer always receives the end marker.
    input_tokens = get_tokens(data.chat_input, data.model_name)
    output_tokens = get_tokens("".join(pieces), data.model_name)
    cost = get_cost(input_tokens, output_tokens, data.model_name)
    yield f"{END_TOKEN},{input_tokens},{output_tokens},{cost}"  # json
|