guidellm 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of guidellm might be problematic. Click here for more details.
- guidellm/__init__.py +19 -0
- guidellm/backend/__init__.py +10 -0
- guidellm/backend/base.py +320 -0
- guidellm/backend/openai.py +168 -0
- guidellm/config.py +234 -0
- guidellm/core/__init__.py +24 -0
- guidellm/core/distribution.py +190 -0
- guidellm/core/report.py +321 -0
- guidellm/core/request.py +44 -0
- guidellm/core/result.py +545 -0
- guidellm/core/serializable.py +169 -0
- guidellm/executor/__init__.py +10 -0
- guidellm/executor/base.py +213 -0
- guidellm/executor/profile_generator.py +343 -0
- guidellm/logger.py +83 -0
- guidellm/main.py +336 -0
- guidellm/request/__init__.py +13 -0
- guidellm/request/base.py +194 -0
- guidellm/request/emulated.py +391 -0
- guidellm/request/file.py +76 -0
- guidellm/request/transformers.py +100 -0
- guidellm/scheduler/__init__.py +4 -0
- guidellm/scheduler/base.py +374 -0
- guidellm/scheduler/load_generator.py +196 -0
- guidellm/utils/__init__.py +40 -0
- guidellm/utils/injector.py +70 -0
- guidellm/utils/progress.py +196 -0
- guidellm/utils/text.py +455 -0
- guidellm/utils/transformers.py +151 -0
- guidellm-0.1.0.dist-info/LICENSE +201 -0
- guidellm-0.1.0.dist-info/METADATA +434 -0
- guidellm-0.1.0.dist-info/RECORD +35 -0
- guidellm-0.1.0.dist-info/WHEEL +5 -0
- guidellm-0.1.0.dist-info/entry_points.txt +3 -0
- guidellm-0.1.0.dist-info/top_level.txt +1 -0
guidellm/__init__.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Guidellm is a package that provides an easy and intuitive interface for
|
|
3
|
+
evaluating and benchmarking large language models (LLMs).
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
# flake8: noqa
|
|
7
|
+
|
|
8
|
+
import os
|
|
9
|
+
import transformers # type: ignore
|
|
10
|
+
|
|
11
|
+
os.environ["TOKENIZERS_PARALLELISM"] = "false" # Silence warnings for tokenizers
|
|
12
|
+
transformers.logging.set_verbosity_error() # Silence warnings for transformers
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
from .config import settings
|
|
16
|
+
from .logger import configure_logger, logger
|
|
17
|
+
from .main import generate_benchmark_report
|
|
18
|
+
|
|
19
|
+
__all__ = ["configure_logger", "logger", "settings", "generate_benchmark_report"]
|
guidellm/backend/base.py
ADDED
|
@@ -0,0 +1,320 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import functools
|
|
3
|
+
from abc import ABC, abstractmethod
|
|
4
|
+
from typing import AsyncGenerator, Dict, List, Literal, Optional, Type, Union
|
|
5
|
+
|
|
6
|
+
from loguru import logger
|
|
7
|
+
from pydantic import BaseModel
|
|
8
|
+
from transformers import ( # type: ignore # noqa: PGH003
|
|
9
|
+
AutoTokenizer,
|
|
10
|
+
PreTrainedTokenizer,
|
|
11
|
+
)
|
|
12
|
+
|
|
13
|
+
from guidellm.core import TextGenerationRequest, TextGenerationResult
|
|
14
|
+
|
|
15
|
+
__all__ = ["Backend", "BackendEngine", "BackendEnginePublic", "GenerativeResponse"]
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
BackendEnginePublic = Literal["openai_server"]
|
|
19
|
+
BackendEngine = Union[BackendEnginePublic, Literal["test"]]
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class GenerativeResponse(BaseModel):
    """
    A single response item produced by a generative AI backend.

    A backend streams zero or more 'token_iter' items, each carrying one piece
    of output, and closes the request with a 'final' item.

    :param type_: Kind of item: 'token_iter' for an intermediate piece of
        output, 'final' for the closing result.
    :type type_: Literal["token_iter", "final"]
    :param add_token: Text to append to the output; only used when type_ is
        'token_iter'.
    :type add_token: Optional[str]
    :param prompt: The prompt that was sent to the backend.
    :type prompt: Optional[str]
    :param output: The complete generated output; only used when type_ is 'final'.
    :type output: Optional[str]
    :param prompt_token_count: Number of tokens in the prompt.
    :type prompt_token_count: Optional[int]
    :param output_token_count: Number of tokens in the output.
    :type output_token_count: Optional[int]
    """

    # Discriminator between intermediate token output and the closing result.
    type_: Literal["token_iter", "final"]

    # Per-item payload; every field below is optional and defaults to None.
    add_token: Union[str, None] = None
    prompt: Union[str, None] = None
    output: Union[str, None] = None

    # Token accounting reported by the backend.
    prompt_token_count: Union[int, None] = None
    output_token_count: Union[int, None] = None
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
class Backend(ABC):
    """
    Abstract base class for generative AI backends.

    This class provides a common interface for creating and interacting with different
    generative AI backends. Subclasses should implement the abstract methods to
    define specific backend behavior.

    :cvar _registry: A dictionary that maps BackendEngine types to backend classes.
    :type _registry: Dict[BackendEngine, Type[Backend]]
    :param type_: The type of the backend.
    :type type_: BackendEngine
    :param target: The target URL for the backend.
    :type target: str
    :param model: The model used by the backend.
    :type model: str
    """

    # One mapping shared by Backend and its subclasses (unless a subclass defines
    # its own); written by ``register`` and read by ``create`` via ``cls``.
    _registry: Dict[BackendEngine, "Type[Backend]"] = {}

    @classmethod
    def register(cls, backend_type: BackendEngine):
        """
        A decorator to register a backend class in the backend registry.

        Registering a ``backend_type`` that is already present replaces the
        previously registered class.

        :param backend_type: The type of backend to register.
        :type backend_type: BackendEngine
        :return: A decorator that stores the decorated class in the registry
            under ``backend_type`` and returns the class unchanged.
        :rtype: Callable[[Type[Backend]], Type[Backend]]
        """

        def inner_wrapper(wrapped_class: Type["Backend"]):
            cls._registry[backend_type] = wrapped_class
            logger.info("Registered backend type: {}", backend_type)
            return wrapped_class

        return inner_wrapper

    @classmethod
    def create(cls, backend_type: BackendEngine, **kwargs) -> "Backend":
        """
        Factory method to create a backend instance based on the backend type.

        :param backend_type: The type of backend to create.
        :type backend_type: BackendEngine
        :param kwargs: Additional arguments for backend initialization, passed
            unchanged to the registered class's constructor.
        :return: An instance of a subclass of Backend.
        :rtype: Backend
        :raises ValueError: If the backend type is not registered.
        """

        logger.info("Creating backend of type {}", backend_type)

        # Both the membership check and the lookup go through ``cls._registry``
        # so they can never consult two different mappings.
        if backend_type not in cls._registry:
            err = ValueError(f"Unsupported backend type: {backend_type}")
            logger.error("{}", err)
            raise err

        return cls._registry[backend_type](**kwargs)

    def __init__(self, type_: BackendEngine, target: str, model: str):
        """
        Base constructor for the Backend class.
        Calls into test_connection to ensure the backend is reachable.
        Ensure all setup is done in the subclass constructor before calling super.

        :param type_: The type of the backend.
        :param target: The target URL for the backend.
        :param model: The model used by the backend.
        :raises RuntimeError: If the connection test fails.
        """
        self._type = type_
        self._target = target
        self._model = model

        # Issues a real generation request (skipped inside a running event loop).
        self.test_connection()

    @property
    def default_model(self) -> str:
        """
        Get the default model for the backend.

        :return: The default model.
        :rtype: str
        :raises ValueError: If no models are available.
        """
        return _cachable_default_model(self)

    @property
    def type_(self) -> BackendEngine:
        """
        Get the type of the backend.

        :return: The type of the backend.
        :rtype: BackendEngine
        """
        return self._type

    @property
    def target(self) -> str:
        """
        Get the target URL for the backend.

        :return: The target URL.
        :rtype: str
        """
        return self._target

    @property
    def model(self) -> str:
        """
        Get the model used by the backend.

        :return: The model name.
        :rtype: str
        """
        return self._model

    def model_tokenizer(self) -> PreTrainedTokenizer:
        """
        Get the tokenizer for the backend model.

        The tokenizer is loaded with ``AutoTokenizer.from_pretrained`` on every
        call; this method does not keep the instance around.

        :return: The tokenizer instance.
        """
        return AutoTokenizer.from_pretrained(self.model)

    def test_connection(self) -> bool:
        """
        Test the connection to the backend by running a short text generation request.
        If successful, returns True, otherwise raises an exception.

        When an asyncio event loop is already running in the current thread the
        test is skipped with a warning, because ``asyncio.run`` cannot be
        called from inside a running loop.

        :return: True if the connection is successful or the test was skipped.
        :rtype: bool
        :raises RuntimeError: If the connection test fails; the underlying
            exception is chained as the cause.
        """
        try:
            asyncio.get_running_loop()
            is_async = True
        except RuntimeError:
            is_async = False

        if is_async:
            logger.warning("Running in async mode, cannot test connection")
            return True

        try:
            request = TextGenerationRequest(
                prompt="Test connection", output_token_count=5
            )

            asyncio.run(self.submit(request))
            return True
        except Exception as err:
            raise_err = RuntimeError(
                f"Backend connection test failed for backend type={self.type_} "
                f"with target={self.target} and model={self.model} with error: {err}"
            )
            logger.error(raise_err)
            raise raise_err from err

    async def submit(self, request: TextGenerationRequest) -> TextGenerationResult:
        """
        Submit a text generation request and return the result.

        This method handles the request submission to the backend and processes
        the response in a streaming fashion if applicable.

        :param request: The request object containing the prompt
            and other configurations.
        :type request: TextGenerationRequest
        :return: The result of the text generation request.
        :rtype: TextGenerationResult
        :raises ValueError: If no final response is received from the backend,
            if more than one final response is received, or if a response of an
            unknown type is received.
        """

        logger.debug("Submitting request with prompt: {}", request.prompt)

        result = TextGenerationResult(request=request)
        result.start(request.prompt)
        received_final = False

        async for response in self.make_request(request):
            logger.debug("Received response: {}", response)
            if response.type_ == "token_iter":
                # A token_iter without text still records a (empty) token.
                result.output_token(response.add_token if response.add_token else "")
            elif response.type_ == "final":
                if received_final:
                    err = ValueError(
                        "Received multiple final responses from the backend."
                    )
                    logger.error(err)
                    raise err

                result.end(
                    output=response.output,
                    prompt_token_count=response.prompt_token_count,
                    output_token_count=response.output_token_count,
                )
                received_final = True
            else:
                err = ValueError(
                    f"Invalid response received from the backend of type: "
                    f"{response.type_} for {response}"
                )
                logger.error(err)
                raise err

        if not received_final:
            err = ValueError("No final response received from the backend.")
            logger.error(err)
            raise err

        logger.info("Request completed with output: {}", result.output)

        return result

    @abstractmethod
    async def make_request(
        self,
        request: TextGenerationRequest,
    ) -> AsyncGenerator[GenerativeResponse, None]:
        """
        Abstract method to make a request to the backend.

        Subclasses must implement this method to define how requests are handled
        by the backend.

        :param request: The request object containing the prompt and
            other configurations.
        :type request: TextGenerationRequest
        :yield: A generator yielding responses from the backend.
        :rtype: AsyncGenerator[GenerativeResponse, None]
        """
        # The yield makes this an async generator so overrides share its shape.
        yield None  # type: ignore # noqa: PGH003

    @abstractmethod
    def available_models(self) -> List[str]:
        """
        Abstract method to get the available models for the backend.

        Subclasses must implement this method to provide the list of models
        supported by the backend.

        :return: A list of available models.
        :rtype: List[str]
        :raises NotImplementedError: If the method is not implemented by a subclass.
        """
        raise NotImplementedError
|
|
297
|
+
|
|
298
|
+
|
|
299
|
+
@functools.lru_cache(maxsize=1)
def _cachable_default_model(backend: Backend) -> str:
    """
    Return the first model reported by ``backend.available_models()``.

    Results are memoized with ``functools.lru_cache(maxsize=1)``, keyed on the
    backend object itself. Only the most recent backend's answer is retained,
    and the cache holds a reference to that backend. A raised error is not
    cached.

    :param backend: The backend instance whose default model is requested.
    :type backend: Backend
    :return: The default model.
    :rtype: str
    :raises ValueError: If the backend reports no models.
    """
    logger.debug("Getting default model for backend: {}", backend)
    candidates = backend.available_models()

    # Guard clause: an empty listing means there is nothing to default to.
    if not candidates:
        no_models = ValueError("No models available.")
        logger.error(no_models)
        raise no_models

    chosen = candidates[0]
    logger.debug("Default model: {}", chosen)
    return chosen
|
|
guidellm/backend/openai.py
ADDED
|
@@ -0,0 +1,168 @@
|
|
|
1
|
+
from typing import AsyncGenerator, Dict, List, Optional
|
|
2
|
+
|
|
3
|
+
from loguru import logger
|
|
4
|
+
from openai import AsyncOpenAI, OpenAI
|
|
5
|
+
|
|
6
|
+
from guidellm.backend.base import Backend, GenerativeResponse
|
|
7
|
+
from guidellm.config import settings
|
|
8
|
+
from guidellm.core import TextGenerationRequest
|
|
9
|
+
|
|
10
|
+
__all__ = ["OpenAIBackend"]
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
@Backend.register("openai_server")
|
|
14
|
+
class OpenAIBackend(Backend):
|
|
15
|
+
"""
|
|
16
|
+
An OpenAI backend implementation for generative AI results.
|
|
17
|
+
|
|
18
|
+
This class provides an interface to communicate with the
|
|
19
|
+
OpenAI server for generating responses based on given prompts.
|
|
20
|
+
|
|
21
|
+
:param openai_api_key: The API key for OpenAI.
|
|
22
|
+
If not provided, it will default to the key from settings.
|
|
23
|
+
:type openai_api_key: Optional[str]
|
|
24
|
+
:param target: The target URL string for the OpenAI server.
|
|
25
|
+
:type target: Optional[str]
|
|
26
|
+
:param model: The OpenAI model to use, defaults to the first available model.
|
|
27
|
+
:type model: Optional[str]
|
|
28
|
+
:param request_args: Additional arguments for the OpenAI request.
|
|
29
|
+
:type request_args: Dict[str, Any]
|
|
30
|
+
"""
|
|
31
|
+
|
|
32
|
+
def __init__(
|
|
33
|
+
self,
|
|
34
|
+
openai_api_key: Optional[str] = None,
|
|
35
|
+
target: Optional[str] = None,
|
|
36
|
+
model: Optional[str] = None,
|
|
37
|
+
**request_args,
|
|
38
|
+
):
|
|
39
|
+
self._request_args: Dict = request_args
|
|
40
|
+
api_key: str = openai_api_key or settings.openai.api_key
|
|
41
|
+
|
|
42
|
+
if not api_key:
|
|
43
|
+
err = ValueError(
|
|
44
|
+
"`GUIDELLM__OPENAI__API_KEY` environment variable or "
|
|
45
|
+
"--openai-api-key CLI parameter must be specified for the "
|
|
46
|
+
"OpenAI backend."
|
|
47
|
+
)
|
|
48
|
+
logger.error("{}", err)
|
|
49
|
+
raise err
|
|
50
|
+
|
|
51
|
+
base_url = target or settings.openai.base_url
|
|
52
|
+
|
|
53
|
+
if not base_url:
|
|
54
|
+
err = ValueError(
|
|
55
|
+
"`GUIDELLM__OPENAI__BASE_URL` environment variable or "
|
|
56
|
+
"target parameter must be specified for the OpenAI backend."
|
|
57
|
+
)
|
|
58
|
+
logger.error("{}", err)
|
|
59
|
+
raise err
|
|
60
|
+
|
|
61
|
+
self._async_client = AsyncOpenAI(api_key=api_key, base_url=base_url)
|
|
62
|
+
self._client = OpenAI(api_key=api_key, base_url=base_url)
|
|
63
|
+
self._model = model or self.default_model
|
|
64
|
+
|
|
65
|
+
super().__init__(type_="openai_server", target=base_url, model=self._model)
|
|
66
|
+
logger.info("OpenAI {} Backend listening on {}", self._model, base_url)
|
|
67
|
+
|
|
68
|
+
async def make_request(
|
|
69
|
+
self,
|
|
70
|
+
request: TextGenerationRequest,
|
|
71
|
+
) -> AsyncGenerator[GenerativeResponse, None]:
|
|
72
|
+
"""
|
|
73
|
+
Make a request to the OpenAI backend.
|
|
74
|
+
|
|
75
|
+
This method sends a prompt to the OpenAI backend and streams
|
|
76
|
+
the response tokens back.
|
|
77
|
+
|
|
78
|
+
:param request: The text generation request to submit.
|
|
79
|
+
:type request: TextGenerationRequest
|
|
80
|
+
:yield: A stream of GenerativeResponse objects.
|
|
81
|
+
:rtype: AsyncGenerator[GenerativeResponse, None]
|
|
82
|
+
"""
|
|
83
|
+
|
|
84
|
+
logger.debug("Making request to OpenAI backend with prompt: {}", request.prompt)
|
|
85
|
+
|
|
86
|
+
request_args: Dict = {
|
|
87
|
+
"n": 1, # Number of completions for each prompt
|
|
88
|
+
}
|
|
89
|
+
|
|
90
|
+
if request.output_token_count is not None:
|
|
91
|
+
request_args.update(
|
|
92
|
+
{
|
|
93
|
+
"max_tokens": request.output_token_count,
|
|
94
|
+
"stop": None,
|
|
95
|
+
}
|
|
96
|
+
)
|
|
97
|
+
elif settings.openai.max_gen_tokens and settings.openai.max_gen_tokens > 0:
|
|
98
|
+
request_args.update(
|
|
99
|
+
{
|
|
100
|
+
"max_tokens": settings.openai.max_gen_tokens,
|
|
101
|
+
}
|
|
102
|
+
)
|
|
103
|
+
|
|
104
|
+
request_args.update(self._request_args)
|
|
105
|
+
|
|
106
|
+
stream = await self._async_client.chat.completions.create(
|
|
107
|
+
model=self.model,
|
|
108
|
+
messages=[
|
|
109
|
+
{"role": "system", "content": request.prompt},
|
|
110
|
+
],
|
|
111
|
+
stream=True,
|
|
112
|
+
**request_args,
|
|
113
|
+
)
|
|
114
|
+
token_count = 0
|
|
115
|
+
async for chunk in stream:
|
|
116
|
+
choice = chunk.choices[0]
|
|
117
|
+
token = choice.delta.content or ""
|
|
118
|
+
|
|
119
|
+
if choice.finish_reason is not None:
|
|
120
|
+
yield GenerativeResponse(
|
|
121
|
+
type_="final",
|
|
122
|
+
prompt=request.prompt,
|
|
123
|
+
prompt_token_count=request.prompt_token_count,
|
|
124
|
+
output_token_count=token_count,
|
|
125
|
+
)
|
|
126
|
+
break
|
|
127
|
+
|
|
128
|
+
token_count += 1
|
|
129
|
+
yield GenerativeResponse(
|
|
130
|
+
type_="token_iter",
|
|
131
|
+
add_token=token,
|
|
132
|
+
prompt=request.prompt,
|
|
133
|
+
prompt_token_count=request.prompt_token_count,
|
|
134
|
+
output_token_count=token_count,
|
|
135
|
+
)
|
|
136
|
+
|
|
137
|
+
def available_models(self) -> List[str]:
|
|
138
|
+
"""
|
|
139
|
+
Get the available models for the backend.
|
|
140
|
+
|
|
141
|
+
This method queries the OpenAI API to retrieve a list of available models.
|
|
142
|
+
|
|
143
|
+
:return: A list of available models.
|
|
144
|
+
:rtype: List[str]
|
|
145
|
+
:raises openai.OpenAIError: If an error occurs while retrieving models.
|
|
146
|
+
"""
|
|
147
|
+
|
|
148
|
+
try:
|
|
149
|
+
return [model.id for model in self._client.models.list().data]
|
|
150
|
+
except Exception as error:
|
|
151
|
+
logger.error("Failed to retrieve available models: {}", error)
|
|
152
|
+
raise error
|
|
153
|
+
|
|
154
|
+
def validate_connection(self):
|
|
155
|
+
"""
|
|
156
|
+
Validate the connection to the OpenAI backend.
|
|
157
|
+
|
|
158
|
+
This method checks that the OpenAI backend is reachable and
|
|
159
|
+
the API key is valid.
|
|
160
|
+
|
|
161
|
+
:raises openai.OpenAIError: If the connection is invalid.
|
|
162
|
+
"""
|
|
163
|
+
|
|
164
|
+
try:
|
|
165
|
+
self._client.models.list()
|
|
166
|
+
except Exception as error:
|
|
167
|
+
logger.error("Failed to validate OpenAI connection: {}", error)
|
|
168
|
+
raise error
|