llmstudio 0.2.2__tar.gz → 0.2.3__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (71) hide show
  1. llmstudio-0.2.3/MANIFEST.in +1 -0
  2. {llmstudio-0.2.2 → llmstudio-0.2.3}/PKG-INFO +1 -1
  3. {llmstudio-0.2.2 → llmstudio-0.2.3}/llmstudio/__init__.py +1 -1
  4. llmstudio-0.2.3/llmstudio/engine/providers/__init__.py +18 -0
  5. llmstudio-0.2.3/llmstudio/engine/providers/base_provider.py +70 -0
  6. llmstudio-0.2.3/llmstudio/engine/providers/bedrock.py +313 -0
  7. llmstudio-0.2.3/llmstudio/engine/providers/openai.py +233 -0
  8. llmstudio-0.2.3/llmstudio/engine/providers/vertexai.py +312 -0
  9. {llmstudio-0.2.2 → llmstudio-0.2.3}/llmstudio/models/models.py +113 -2
  10. llmstudio-0.2.3/llmstudio/ui/build/asset-manifest.json +19 -0
  11. llmstudio-0.2.3/llmstudio/ui/build/fonts/Satoshi-Black.woff +0 -0
  12. llmstudio-0.2.3/llmstudio/ui/build/fonts/Satoshi-Bold.woff +0 -0
  13. llmstudio-0.2.3/llmstudio/ui/build/fonts/Satoshi-Light.woff +0 -0
  14. llmstudio-0.2.3/llmstudio/ui/build/fonts/Satoshi-Medium.woff +0 -0
  15. llmstudio-0.2.3/llmstudio/ui/build/fonts/Satoshi-Regular.woff +0 -0
  16. llmstudio-0.2.3/llmstudio/ui/build/fonts/VioletSans-Regular.woff +0 -0
  17. llmstudio-0.2.3/llmstudio/ui/build/images/claudio.jpg +0 -0
  18. llmstudio-0.2.3/llmstudio/ui/build/images/icon.png +0 -0
  19. llmstudio-0.2.3/llmstudio/ui/build/index.html +1 -0
  20. llmstudio-0.2.3/llmstudio/ui/build/manifest.json +25 -0
  21. llmstudio-0.2.3/llmstudio/ui/build/robots.txt +3 -0
  22. llmstudio-0.2.3/llmstudio/ui/build/static/css/main.0342ffa4.css +4 -0
  23. llmstudio-0.2.3/llmstudio/ui/build/static/css/main.0342ffa4.css.map +1 -0
  24. llmstudio-0.2.3/llmstudio/ui/build/static/js/main.7337aa4e.js +3 -0
  25. llmstudio-0.2.3/llmstudio/ui/build/static/js/main.7337aa4e.js.LICENSE.txt +83 -0
  26. llmstudio-0.2.3/llmstudio/ui/build/static/js/main.7337aa4e.js.map +1 -0
  27. llmstudio-0.2.3/llmstudio/ui/build/static/media/Satoshi-Black.4261d202e1e9410db1bf.woff +0 -0
  28. llmstudio-0.2.3/llmstudio/ui/build/static/media/Satoshi-Bold.a875ff682ee232938607.woff +0 -0
  29. llmstudio-0.2.3/llmstudio/ui/build/static/media/Satoshi-Light.67e7fa77f107df3491b6.woff +0 -0
  30. llmstudio-0.2.3/llmstudio/ui/build/static/media/Satoshi-Medium.2419b46c96ed15331ba2.woff +0 -0
  31. llmstudio-0.2.3/llmstudio/ui/build/static/media/Satoshi-Regular.ca3da5fd2b609836ef69.woff +0 -0
  32. llmstudio-0.2.3/llmstudio/ui/build/static/media/VioletSans-Regular.425614770e8617faebdd.woff +0 -0
  33. llmstudio-0.2.3/llmstudio/ui/build/svg/ai.svg +10 -0
  34. llmstudio-0.2.3/llmstudio/ui/build/svg/arrow.svg +3 -0
  35. llmstudio-0.2.3/llmstudio/ui/build/svg/home.svg +12 -0
  36. llmstudio-0.2.3/llmstudio/ui/build/svg/load.svg +4 -0
  37. llmstudio-0.2.3/llmstudio/ui/build/svg/magic.svg +10 -0
  38. llmstudio-0.2.3/llmstudio/ui/build/svg/play.svg +11 -0
  39. llmstudio-0.2.3/llmstudio/ui/build/svg/playground.svg +8 -0
  40. llmstudio-0.2.3/llmstudio/ui/build/svg/plus.svg +9 -0
  41. llmstudio-0.2.3/llmstudio/ui/build/svg/prompt.svg +7 -0
  42. llmstudio-0.2.3/llmstudio/ui/build/svg/settings.svg +15 -0
  43. llmstudio-0.2.3/llmstudio/ui/build/svg/sparkles.svg +3 -0
  44. llmstudio-0.2.3/llmstudio/utils/__init__.py +0 -0
  45. {llmstudio-0.2.2 → llmstudio-0.2.3}/llmstudio.egg-info/PKG-INFO +1 -1
  46. llmstudio-0.2.3/llmstudio.egg-info/SOURCES.txt +68 -0
  47. {llmstudio-0.2.2 → llmstudio-0.2.3}/setup.py +6 -18
  48. llmstudio-0.2.2/llmstudio.egg-info/SOURCES.txt +0 -27
  49. {llmstudio-0.2.2 → llmstudio-0.2.3}/LICENSE +0 -0
  50. {llmstudio-0.2.2 → llmstudio-0.2.3}/README.md +0 -0
  51. {llmstudio-0.2.2 → llmstudio-0.2.3}/llmstudio/cli.py +0 -0
  52. {llmstudio-0.2.2 → llmstudio-0.2.3}/llmstudio/client.py +0 -0
  53. {llmstudio-0.2.2 → llmstudio-0.2.3}/llmstudio/engine/__init__.py +0 -0
  54. {llmstudio-0.2.2 → llmstudio-0.2.3}/llmstudio/engine/config.py +0 -0
  55. {llmstudio-0.2.2 → llmstudio-0.2.3}/llmstudio/engine/constants.py +0 -0
  56. {llmstudio-0.2.2 → llmstudio-0.2.3}/llmstudio/engine/utils.py +0 -0
  57. {llmstudio-0.2.2 → llmstudio-0.2.3}/llmstudio/models/__init__.py +0 -0
  58. {llmstudio-0.2.2 → llmstudio-0.2.3}/llmstudio/models/bedrock.py +0 -0
  59. {llmstudio-0.2.2 → llmstudio-0.2.3}/llmstudio/models/openai.py +0 -0
  60. {llmstudio-0.2.2 → llmstudio-0.2.3}/llmstudio/models/vertexai.py +0 -0
  61. {llmstudio-0.2.2 → llmstudio-0.2.3}/llmstudio/ui/__init__.py +0 -0
  62. {llmstudio-0.2.2 → llmstudio-0.2.3}/llmstudio/utils/rest_utils.py +0 -0
  63. {llmstudio-0.2.2 → llmstudio-0.2.3}/llmstudio/validators/__init__.py +0 -0
  64. {llmstudio-0.2.2 → llmstudio-0.2.3}/llmstudio/validators/bedrock.py +0 -0
  65. {llmstudio-0.2.2 → llmstudio-0.2.3}/llmstudio/validators/openai.py +0 -0
  66. {llmstudio-0.2.2 → llmstudio-0.2.3}/llmstudio/validators/vertexai.py +0 -0
  67. {llmstudio-0.2.2 → llmstudio-0.2.3}/llmstudio.egg-info/dependency_links.txt +0 -0
  68. {llmstudio-0.2.2 → llmstudio-0.2.3}/llmstudio.egg-info/entry_points.txt +0 -0
  69. {llmstudio-0.2.2 → llmstudio-0.2.3}/llmstudio.egg-info/requires.txt +0 -0
  70. {llmstudio-0.2.2 → llmstudio-0.2.3}/llmstudio.egg-info/top_level.txt +0 -0
  71. {llmstudio-0.2.2 → llmstudio-0.2.3}/setup.cfg +0 -0
@@ -0,0 +1 @@
1
+ recursive-include llmstudio/ui/build *
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: llmstudio
3
- Version: 0.2.2
3
+ Version: 0.2.3
4
4
  Summary: Prompt Perfection at Your Fingertips
5
5
  Home-page: https://llmstudio.ai/
6
6
  Author: TensorOps
@@ -1,5 +1,5 @@
1
1
  name = "version"
2
- __version__ = "0.2.2"
2
+ __version__ = "0.2.3"
3
3
 
4
4
  __requirements__ = [
5
5
  "pydantic",
@@ -0,0 +1,18 @@
1
+ from llmstudio.engine.config import Provider
2
+ from llmstudio.engine.providers.base_provider import BaseProvider
3
+
4
+
5
def get_provider(provider: Provider) -> "type[BaseProvider]":
    """
    Look up the provider class registered for ``provider``.

    The class itself is returned, not an instance; the caller instantiates it
    with the provider's config and credentials.

    Args:
        provider (Provider): The provider enum member to resolve.

    Returns:
        type[BaseProvider]: The provider class mapped to ``provider``.

    Raises:
        ValueError: If no provider class is registered for ``provider``.
    """
    # Imported inside the function. This defers loading the provider modules
    # (and the SDKs they import, e.g. boto3/openai) until a provider is
    # requested; presumably it also avoids a circular import with this package.
    from llmstudio.engine.providers.bedrock import BedrockProvider
    from llmstudio.engine.providers.openai import OpenAIProvider
    from llmstudio.engine.providers.vertexai import VertexAIProvider

    provider_to_class = {
        Provider.OPENAI: OpenAIProvider,
        Provider.VERTEXAI: VertexAIProvider,
        Provider.BEDROCK: BedrockProvider,
    }
    if prov := provider_to_class.get(provider):
        return prov

    raise ValueError(f"Provider {provider} not found")
@@ -0,0 +1,70 @@
1
+ from abc import ABC
2
+ from concurrent.futures import ThreadPoolExecutor
3
+
4
+
5
class BaseProvider(ABC):
    """
    Abstract base class for LLMStudio engine providers.

    This class defines the core interface for providers that integrate with
    LLMStudio's engine. Subclasses override ``chat`` and ``test``; the base
    implementations only raise ``NotImplementedError``.

    Attributes:
        executor (ThreadPoolExecutor): Shared pool of 10 worker threads that
            subclasses can hand to ``loop.run_in_executor`` for blocking calls.
    """

    def __init__(self):
        super().__init__()
        self.executor = ThreadPoolExecutor(max_workers=10)

    async def chat(self, data) -> dict:
        """
        Asynchronously handle a chat request.

        Parameters:
            data: The data payload for the chat operation.

        Returns:
            dict: A dictionary containing the chat response.

        Raises:
            NotImplementedError: This method is intended to be overridden by subclasses.
        """
        raise NotImplementedError

    async def test(self, data) -> dict:
        """
        Asynchronously handle a test request.

        Parameters:
            data: The data payload for the test operation.

        Returns:
            dict: A dictionary containing the test response.

        Raises:
            NotImplementedError: This method is intended to be overridden by subclasses.
        """
        raise NotImplementedError

    def validate_model_field(self, data, model_list):
        """
        Validate the ``model_name`` attribute of the request data.

        Parameters:
            data: Request object exposing a ``model_name`` attribute.
            model_list: Container of valid model names (anything supporting ``in``).

        Returns:
            None when the model name is present and valid.

        Raises:
            HTTPException: With status 422 if ``model_name`` is empty or is not
                contained in ``model_list``.
        """
        model_name = data.model_name
        if not model_name:
            detail = "The parameter 'model_name' is mandatory to be passed in the request body."
        elif model_name not in model_list:
            # ``data`` is accessed by attribute (callers pass pydantic request
            # models), so the message must not subscript it like a dict.
            detail = f"The model '{model_name}' does not exist."
        else:
            return None

        # Imported only when an error must be raised.
        from fastapi import HTTPException

        raise HTTPException(status_code=422, detail=detail)
@@ -0,0 +1,313 @@
1
+ import asyncio
2
+ import json
3
+ import random
4
+ import time
5
+ from typing import Optional, Tuple
6
+
7
+ import boto3
8
+ from fastapi.responses import StreamingResponse
9
+ from pydantic import BaseModel, Field, validator
10
+
11
+ from llmstudio.engine.config import BedrockConfig
12
+ from llmstudio.engine.constants import BEDROCK_MODELS, CLAUDE_MODELS, END_TOKEN, TITAN_MODELS
13
+ from llmstudio.engine.providers.base_provider import BaseProvider
14
+ from llmstudio.engine.utils import validate_provider_config
15
+
16
+
17
class ClaudeParameters(BaseModel):
    """
    Model for validating and storing parameters specific to Claude models.

    Attributes:
        temperature (Optional[float]): Controls randomness in the model's output.
            Range 0 to 1, default 1.
        max_tokens (Optional[int]): The maximum number of tokens in the output.
            Range 1 to 2048, default 300.
        top_p (Optional[float]): Influences the diversity of output by controlling
            token sampling. Range 0 to 1, default 0.999.
        top_k (Optional[int]): Sets the number of the most likely next tokens to
            filter for. Range 1 to 500, default 250.
    """

    temperature: Optional[float] = Field(1, ge=0, le=1)
    max_tokens: Optional[int] = Field(300, ge=1, le=2048)
    top_p: Optional[float] = Field(0.999, ge=0, le=1)
    top_k: Optional[int] = Field(250, ge=1, le=500)
32
+
33
+
34
class TitanParameters(BaseModel):
    """
    Validated generation parameters for Titan models on Bedrock.

    Attributes:
        temperature (Optional[float]): Randomness of the output; 0 to 1, default 0.
        max_tokens (Optional[int]): Upper bound on output tokens; 1 to 4096, default 512.
        top_p (Optional[float]): Controls token-sampling diversity; 0.1 to 1, default 0.9.
    """

    temperature: Optional[float] = Field(default=0, ge=0, le=1)
    max_tokens: Optional[int] = Field(default=512, ge=1, le=4096)
    top_p: Optional[float] = Field(default=0.9, ge=0.1, le=1)
47
+
48
+
49
class BedrockRequest(BaseModel):
    """
    Represents a request to the Bedrock API.

    Attributes:
        api_key (Optional[str]): The API key for authenticating with the Bedrock API.
        api_secret (Optional[str]): The API secret for authenticating with the Bedrock API.
        api_region (Optional[str]): The region where the Bedrock API is hosted.
        model_name (str): The name of the model to be used for the request.
        chat_input (str): The input string for the chat.
        parameters (Optional[BaseModel]): Model-specific parameters, converted to
            ``TitanParameters`` or ``ClaudeParameters`` based on ``model_name``.
            When omitted, that model's defaults are used.
        is_stream (Optional[bool]): Flag to indicate if the request is for streaming. Defaults to False.
    """

    api_key: Optional[str]
    api_secret: Optional[str]
    api_region: Optional[str]
    model_name: str
    chat_input: str
    parameters: Optional[BaseModel]
    is_stream: Optional[bool] = False

    @validator("parameters", pre=True, always=True)
    def validate_parameters_based_on_model_name(cls, parameters, values):
        """
        Validate and convert parameters based on the model_name.

        Args:
            parameters: A mapping of parameter values, an already-built
                parameters model, or None when the field was omitted.
            values (Dict[str, Any]): Contains previously validated fields
                (``model_name`` is declared before ``parameters``).

        Returns:
            BaseModel: An instance of `TitanParameters` or `ClaudeParameters` based on `model_name`.

        Raises:
            ValueError: If model_name is not a Titan or Claude model.
        """
        # ``always=True`` runs this validator even when the field is omitted, in
        # which case ``parameters`` is None; ``**None`` would fail, so fall back
        # to the model's defaults instead.
        if parameters is None:
            parameters = {}
        elif isinstance(parameters, BaseModel):
            # ``**`` needs a mapping; pydantic models are not mappings.
            parameters = parameters.dict()

        model_name = values.get("model_name")
        if model_name in TITAN_MODELS:
            return TitanParameters(**parameters)
        if model_name in CLAUDE_MODELS:
            return ClaudeParameters(**parameters)

        raise ValueError(f"Invalid model_name: {model_name}")
93
+
94
+
95
class BedrockTest(BaseModel):
    """
    A Pydantic model for Bedrock credential/model test requests.

    Attributes:
        api_key (Optional[str]): The API key provided by the user for authentication with Bedrock's API.
        api_secret (Optional[str]): The API secret key provided by the user for authentication.
        api_region (Optional[str]): The API region for Bedrock API requests.
        model_name (str): The name of the model intended for use with the Bedrock API.

    Note:
        No validators are defined here. ``BedrockProvider.test`` reads the
        credentials from the provider config and uses only ``model_name`` from
        this model.
    """

    api_key: Optional[str]
    api_secret: Optional[str]
    api_region: Optional[str]
    model_name: str
113
+
114
+
115
class BedrockProvider(BaseProvider):
    """
    BedrockProvider class to interact with the Bedrock API.

    Attributes:
        bedrock_config: Validated configuration returned by
            ``validate_provider_config``; read as a mapping with ``api_key``,
            ``api_secret`` and ``api_region`` entries.
    """

    def __init__(self, config: BedrockConfig, api_key: dict):
        """
        Initialize the BedrockProvider class.

        Args:
            config (BedrockConfig): Configuration for the Bedrock API.
            api_key (dict): API key required for the Bedrock API.

        Raises:
            ValidationError: If the provided config and API key are invalid.
        """
        super().__init__()
        self.bedrock_config = validate_provider_config(config, api_key)

    def _get_client(self):
        """Create a boto3 ``bedrock`` client from the configured credentials and region."""
        session = boto3.Session(
            aws_access_key_id=self.bedrock_config["api_key"],
            aws_secret_access_key=self.bedrock_config["api_secret"],
        )
        return session.client(
            service_name="bedrock", region_name=self.bedrock_config["api_region"]
        )

    async def chat(self, data: dict) -> dict:
        """
        Process chat input via the Bedrock API and generate the model's response.

        Args:
            data (dict): Raw request payload; validated into ``BedrockRequest`` here.

        Returns:
            Union[StreamingResponse, dict]: A streaming response if ``is_stream``
            is True, otherwise a dict with the chat and token data.

        Raises:
            HTTPException: 422 if ``model_name`` is missing or not in ``BEDROCK_MODELS``.
            ValidationError: If the payload does not validate as a ``BedrockRequest``.
        """
        data = BedrockRequest(**data)
        self.validate_model_field(data, BEDROCK_MODELS)
        loop = asyncio.get_running_loop()
        bedrock = self._get_client()

        body, response_keys = generate_body_and_response(data)
        request = {
            "body": json.dumps(body),
            "modelId": data.model_name,
            "accept": "application/json",
            "contentType": "application/json",
        }

        if data.is_stream:
            # Run the blocking boto3 call on the provider's shared thread pool.
            stream = await loop.run_in_executor(
                self.executor,
                lambda: bedrock.invoke_model_with_response_stream(**request).get("body"),
            )
            return StreamingResponse(generate_stream_response(stream, response_keys))

        response = await loop.run_in_executor(
            self.executor,
            lambda: json.loads(bedrock.invoke_model(**request).get("body").read()),
        )

        # Titan requests set ``use_results`` (output nested under "results");
        # Claude requests read the output from the top-level object.
        if response_keys["use_results"]:
            response = response["results"][0]

        # A None key (as used for Claude) is simply absent, so counts default to 0.
        input_tokens = response.get(response_keys["input_tokens_key"], 0)
        output_tokens = response.get(response_keys["output_tokens_key"], 0)

        return {
            "id": random.randint(0, 1000),
            "chatInput": data.chat_input,
            "chatOutput": response[response_keys["output_key"]],
            "inputTokens": input_tokens,
            "outputTokens": output_tokens,
            "totalTokens": input_tokens + output_tokens,
            "cost": 0,  # TODO: Bedrock pricing is not computed yet.
            "timestamp": time.time(),
            "modelName": data.model_name,
            "parameters": data.parameters.dict(),
        }

    async def test(self, data: dict) -> bool:
        """
        Test the validity of the Bedrock API credentials and model name.

        Credentials and region come from the provider config; only
        ``model_name`` is taken from ``data``.

        Args:
            data (dict): Raw payload; validated into ``BedrockTest`` here.

        Returns:
            bool: `True` if the configured credentials can list foundation models
            and ``model_name`` is among them, otherwise `False`.
        """
        data = BedrockTest(**data)
        loop = asyncio.get_running_loop()
        try:
            bedrock = self._get_client()
            # list_foundation_models is a blocking call; keep it off the event loop.
            response = await loop.run_in_executor(
                self.executor, bedrock.list_foundation_models
            )
            return any(
                summary["modelId"] == data.model_name
                for summary in response["modelSummaries"]
            )
        except Exception:
            # Deliberate best-effort probe: any failure (e.g. bad credentials,
            # region or network) is reported as "not valid" rather than raised.
            return False
231
+
232
+
233
def generate_body_and_response(data: "BedrockRequest") -> Tuple[dict, dict]:
    """
    Generate the request body and response keys based on the model name.

    Args:
        data (BedrockRequest): Validated API request data.

    Returns:
        Tuple[dict, dict]: ``(body, response_keys)``.
            - ``body`` is the JSON-serializable request body for the model family.
            - ``response_keys`` tells the caller how to read the reply. It holds
              ``output_key``, ``input_tokens_key``, ``output_tokens_key`` and
              ``use_results``, the last meaning the output is nested under
              ``"results"[0]``.

    Raises:
        ValueError: If ``model_name`` is neither a Titan nor a Claude model.
    """
    params = data.parameters

    if data.model_name in TITAN_MODELS:
        body = {
            "inputText": data.chat_input,
            "textGenerationConfig": {
                "maxTokenCount": params.max_tokens,
                "temperature": params.temperature,
                "topP": params.top_p,
            },
        }
        response_keys = {
            "output_key": "outputText",
            "input_tokens_key": "inputTextTokenCount",
            "output_tokens_key": "tokenCount",
            "use_results": True,
        }
    elif data.model_name in CLAUDE_MODELS:
        body = {
            "prompt": data.chat_input,
            "max_tokens_to_sample": params.max_tokens,
            "temperature": params.temperature,
            "top_k": params.top_k,
            "top_p": params.top_p,
        }
        # Token counts are not read from Claude responses (keys are None).
        response_keys = {
            "output_key": "completion",
            "input_tokens_key": None,
            "output_tokens_key": None,
            "use_results": False,
        }
    else:
        raise ValueError(f"Invalid model_name: {data.model_name}")

    return body, response_keys
275
+
276
+
277
def get_cost(input_tokens: int, output_tokens: int) -> Optional[float]:
    """
    Calculate the cost based on input and output tokens.

    Placeholder: no Bedrock pricing data is wired up yet, so this always
    returns None regardless of the token counts.

    Args:
        input_tokens (int): Number of tokens in the input.
        output_tokens (int): Number of tokens in the output.

    Returns:
        Optional[float]: Currently always None (cost unknown).
    """
    return None
289
+
290
+
291
def generate_stream_response(response, response_keys):
    """
    Generate a streaming response from Bedrock stream events.

    Args:
        response (Any): Iterable of events from ``invoke_model_with_response_stream``.
            Events carrying a ``"chunk"`` hold JSON-encoded bytes under ``"bytes"``.
        response_keys (Dict[str, Any]): Keys to extract relevant data from the
            response; ``response_keys["output_key"]`` selects the text field.

    Yields:
        str: The text of each chunk, followed by a final trailer line
        ``"{END_TOKEN},{input_tokens},{output_tokens},{cost}"``.
    """
    for event in response:
        chunk = event.get("chunk")
        if chunk:
            yield json.loads(chunk.get("bytes").decode())[response_keys["output_key"]]

    # Usage and cost are not computed for Bedrock streams yet; report zeros.
    input_tokens = 0
    output_tokens = 0
    cost = 0
    yield f"{END_TOKEN},{input_tokens},{output_tokens},{cost}"
@@ -0,0 +1,233 @@
1
+ import asyncio
2
+ import random
3
+ import time
4
+ from typing import Optional
5
+
6
+ import openai
7
+ import tiktoken
8
+ from fastapi.responses import StreamingResponse
9
+ from pydantic import BaseModel, Field
10
+
11
+ from llmstudio.engine.config import OpenAIConfig
12
+ from llmstudio.engine.constants import END_TOKEN, OPENAI_PRICING_DICT
13
+ from llmstudio.engine.providers.base_provider import BaseProvider
14
+ from llmstudio.engine.utils import validate_provider_config
15
+
16
+
17
class OpenAIParameters(BaseModel):
    """
    A Pydantic model for encapsulating parameters used in OpenAI API requests.

    Attributes:
        temperature (Optional[float]): Controls randomness in the model's output.
            Range 0 to 2, default 1.
        max_tokens (Optional[int]): The maximum number of tokens in the output.
            Range 1 to 2048, default 256.
        top_p (Optional[float]): Influences the diversity of output by controlling
            token sampling. Range 0 to 1, default 1.
        frequency_penalty (Optional[float]): Modifies the likelihood of tokens
            appearing based on their frequency. Range -2 to 2, default 0.
        presence_penalty (Optional[float]): Adjusts the likelihood of new tokens
            appearing. Range -2 to 2, default 0.
    """

    temperature: Optional[float] = Field(default=1, ge=0, le=2)
    max_tokens: Optional[int] = Field(default=256, ge=1, le=2048)
    top_p: Optional[float] = Field(default=1, ge=0, le=1)
    # The OpenAI API accepts both penalties in [-2.0, 2.0]. Negative values are
    # valid, so the previous [0, 1] bound rejected legitimate requests.
    frequency_penalty: Optional[float] = Field(default=0, ge=-2, le=2)
    presence_penalty: Optional[float] = Field(default=0, ge=-2, le=2)
34
+
35
+
36
class OpenAIRequest(BaseModel):
    """
    Request payload for an OpenAI chat completion.

    Attributes:
        api_key (Optional[str]): API key to authenticate the request with.
        model_name (str): Name of the language model to query.
        chat_input (str): Text sent to the model as the user message.
        parameters (Optional[OpenAIParameters]): Generation settings; a fresh
            ``OpenAIParameters()`` with its defaults when not supplied.
        is_stream (Optional[bool]): Whether to stream the reply; False by default.
    """

    api_key: Optional[str]
    model_name: str
    chat_input: str
    parameters: Optional[OpenAIParameters] = Field(default_factory=OpenAIParameters)
    is_stream: Optional[bool] = False
54
+
55
+
56
class OpenAITest(BaseModel):
    """
    A Pydantic model for OpenAI credential/model test requests.

    Attributes:
        api_key (Optional[str]): The API key provided by the user for
            authentication with the OpenAI API.
        model_name (str): The name of the model to be used for generating text.

    Note:
        No validators are defined here. ``OpenAIProvider.test`` checks
        ``model_name`` itself and uses the API key from the provider config.
    """

    api_key: Optional[str]
    model_name: str
71
+
72
+
73
class OpenAIProvider(BaseProvider):
    """
    A provider class to handle interactions with the OpenAI GPT models.

    Attributes:
        openai_config (OpenAIConfig): Configuration settings for OpenAI API.
    """

    def __init__(self, config: OpenAIConfig, api_key: str):
        """
        Initialize the OpenAIProvider with given config and API key.

        An ``OpenAIConfig`` instance is used as-is. Anything else is first
        validated together with ``api_key`` via ``validate_provider_config``.

        Args:
            config (OpenAIConfig): Configuration settings for OpenAI API.
            api_key (str): API key for authentication.
        """
        super().__init__()
        if isinstance(config, OpenAIConfig):
            self.openai_config = config
        else:
            self.openai_config = OpenAIConfig(**validate_provider_config(config, api_key))

    async def chat(self, data: dict) -> dict:
        """
        Generate chat-based model completions using OpenAI API.

        Args:
            data (dict): Raw request payload; validated into ``OpenAIRequest`` here.

        Returns:
            Union[StreamingResponse, dict]: A streaming response when ``is_stream``
            is set. Otherwise a dict with the chat input, chat output, token
            counts, cost and other metadata.

        Raises:
            HTTPException: 422 if ``model_name`` is missing or not in ``OPENAI_PRICING_DICT``.
            ValidationError: If the payload does not validate as an ``OpenAIRequest``.
        """
        data = OpenAIRequest(**data)
        self.validate_model_field(data, OPENAI_PRICING_DICT.keys())

        api_key = self.openai_config.api_key
        # The module-level key is still set so that other code reading
        # ``openai.api_key`` keeps working. The key is also passed per request:
        # several providers with different keys can run concurrently on the
        # thread pool, and the global alone would let them race.
        openai.api_key = api_key

        params = data.parameters
        loop = asyncio.get_running_loop()
        # Run the blocking SDK call on the shared thread pool for parallelism.
        response = await loop.run_in_executor(
            self.executor,
            lambda: openai.ChatCompletion.create(
                api_key=api_key,
                model=data.model_name,
                messages=[
                    {
                        "role": "user",
                        "content": data.chat_input,
                    }
                ],
                temperature=params.temperature,
                max_tokens=params.max_tokens,
                top_p=params.top_p,
                frequency_penalty=params.frequency_penalty,
                presence_penalty=params.presence_penalty,
                stream=data.is_stream,
            ),
        )

        if data.is_stream:
            return StreamingResponse(generate_stream_response(response, data))

        chat_output = response["choices"][0]["message"]["content"]
        input_tokens = get_tokens(data.chat_input, data.model_name)
        output_tokens = get_tokens(chat_output, data.model_name)

        return {
            "id": random.randint(0, 1000),
            "chatInput": data.chat_input,
            "chatOutput": chat_output,
            "inputTokens": input_tokens,
            "outputTokens": output_tokens,
            "totalTokens": input_tokens + output_tokens,
            "cost": get_cost(input_tokens, output_tokens, data.model_name),
            "timestamp": time.time(),
            "modelName": data.model_name,
            "parameters": params.dict(),
        }

    async def test(self, data: dict) -> bool:
        """
        Test the validity of the OpenAI API key and model name.

        Args:
            data (dict): Raw payload; validated into ``OpenAITest`` here.

        Returns:
            bool: `True` if ``model_name`` is a known model and can be retrieved
            with the configured API key, otherwise `False`.
        """
        api_key = self.openai_config.api_key
        openai.api_key = api_key
        data = OpenAITest(**data)
        try:
            self.validate_model_field(data, OPENAI_PRICING_DICT.keys())
            loop = asyncio.get_running_loop()
            # Model.retrieve is a blocking network call; keep it off the event loop.
            await loop.run_in_executor(
                self.executor,
                lambda: openai.Model.retrieve(data.model_name, api_key=api_key),
            )
            return True
        except Exception:
            # Deliberate best-effort probe: any failure means "not valid".
            return False
174
+
175
+
176
def get_cost(input_tokens: int, output_tokens: int, model_name: str) -> float:
    """
    Price an OpenAI call from its token usage.

    Input and output tokens are each multiplied by the model's per-token rate
    from ``OPENAI_PRICING_DICT`` and summed.

    Args:
        input_tokens (int): Number of prompt tokens.
        output_tokens (int): Number of completion tokens.
        model_name (str): Key into ``OPENAI_PRICING_DICT``.

    Returns:
        float: Total cost of the call.

    Raises:
        KeyError: If ``model_name`` has no entry in ``OPENAI_PRICING_DICT``.
    """
    pricing = OPENAI_PRICING_DICT[model_name]
    input_cost = pricing["input_tokens"] * input_tokens
    output_cost = pricing["output_tokens"] * output_tokens
    return input_cost + output_cost
192
+
193
+
194
def get_tokens(chat_input: str, model_name: str) -> int:
    """
    Count the tokens in ``chat_input`` using the tokenizer tiktoken maps to ``model_name``.

    Args:
        chat_input (str): Text to tokenize.
        model_name (str): Model identifier used to select the tiktoken encoding.

    Returns:
        int: Number of tokens the text encodes to.
    """
    encoding = tiktoken.encoding_for_model(model_name)
    token_ids = encoding.encode(chat_input)
    return len(token_ids)
207
+
208
+
209
def generate_stream_response(response, data: "OpenAIRequest"):
    """
    Relay streamed chat output, then emit token counts and cost at stream end.

    Args:
        response: Iterable of streamed chunks from ``openai.ChatCompletion.create(stream=True)``.
            Each chunk holds ``choices[0]`` with a ``delta`` and a ``finish_reason``.
        data (OpenAIRequest): The request, used for ``chat_input`` and
            ``model_name`` when computing usage and cost.

    Yields:
        str: Each non-empty content delta, then one trailer line
        ``"{END_TOKEN},{input_tokens},{output_tokens},{cost}"``.

    Notes:
        - A delta may lack ``content``, e.g. a role-only first delta or the
          final chunk. Such deltas contribute nothing instead of raising KeyError.
        - Any non-None ``finish_reason`` ends the stream, not only "stop" and
          "length".
        - The trailer is emitted even if the stream ends without a finish chunk.
    """
    output_parts = []
    for chunk in response:
        choice = chunk["choices"][0]
        content = choice["delta"].get("content")
        if content:
            output_parts.append(content)
            yield content
        if choice["finish_reason"] is not None:
            break

    # Join once instead of repeated ``+=`` (quadratic on long outputs).
    chat_output = "".join(output_parts)
    input_tokens = get_tokens(data.chat_input, data.model_name)
    output_tokens = get_tokens(chat_output, data.model_name)
    cost = get_cost(input_tokens, output_tokens, data.model_name)
    # Final trailer line: END_TOKEN followed by comma-separated usage and cost.
    yield f"{END_TOKEN},{input_tokens},{output_tokens},{cost}"