py-adtools 0.3.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- adtools/__init__.py +1 -0
- adtools/cli.py +61 -0
- adtools/evaluator/__init__.py +2 -0
- adtools/evaluator/auto_server.py +258 -0
- adtools/evaluator/py_evaluator.py +170 -0
- adtools/evaluator/py_evaluator_ray.py +110 -0
- adtools/lm/__init__.py +4 -0
- adtools/lm/lm_base.py +63 -0
- adtools/lm/openai_api.py +118 -0
- adtools/lm/sglang_server.py +423 -0
- adtools/lm/vllm_server.py +452 -0
- adtools/py_code.py +577 -0
- adtools/sandbox/__init__.py +2 -0
- adtools/sandbox/sandbox_executor.py +244 -0
- adtools/sandbox/sandbox_executor_ray.py +194 -0
- adtools/sandbox/utils.py +32 -0
- py_adtools-0.3.2.dist-info/METADATA +567 -0
- py_adtools-0.3.2.dist-info/RECORD +22 -0
- py_adtools-0.3.2.dist-info/WHEEL +5 -0
- py_adtools-0.3.2.dist-info/entry_points.txt +2 -0
- py_adtools-0.3.2.dist-info/licenses/LICENSE +21 -0
- py_adtools-0.3.2.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,452 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Copyright (c) 2025 Rui Zhang <rzhang.cs@gmail.com>
|
|
3
|
+
|
|
4
|
+
NOTICE: This code is under MIT license. This code is intended for academic/research purposes only.
|
|
5
|
+
Commercial use of this software or its derivatives requires prior written permission.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from typing import Optional, List, Literal, Dict, Any
|
|
9
|
+
import os
|
|
10
|
+
import subprocess
|
|
11
|
+
import sys
|
|
12
|
+
from pathlib import Path
|
|
13
|
+
import psutil
|
|
14
|
+
import time
|
|
15
|
+
|
|
16
|
+
import openai.types.chat
|
|
17
|
+
import requests
|
|
18
|
+
|
|
19
|
+
from adtools.lm.lm_base import LanguageModel
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def _print_cmd_list(cmd_list, gpus, host, port):
|
|
23
|
+
print("\n" + "=" * 80)
|
|
24
|
+
print(f"[vLLM] Launching vLLM on GPU:{gpus}; URL: https://{host}:{port}")
|
|
25
|
+
print("=" * 80)
|
|
26
|
+
cmd = cmd_list[0] + " \\\n"
|
|
27
|
+
for c in cmd_list[1:]:
|
|
28
|
+
cmd += " " + c + " \\\n"
|
|
29
|
+
print(cmd.strip())
|
|
30
|
+
print("=" * 80 + "\n", flush=True)
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class VLLMServer(LanguageModel):
    """Launch a vLLM OpenAI-compatible API server as a subprocess and query it over HTTP."""

    def __init__(
        self,
        model_path: str,
        port: int,
        gpus: int | list[int],
        tokenizer_path: Optional[str] = None,
        max_model_len: int = 16384,
        max_lora_rank: Optional[int] = None,
        host: str = "0.0.0.0",
        mem_util: float = 0.85,
        deploy_timeout_seconds: int = 600,
        *,
        launch_vllm_in_init=True,
        enforce_eager: bool = False,
        vllm_log_level: Literal[
            "DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"
        ] = "INFO",
        silent_mode: bool = False,
        env_variable_dict: Optional[Dict[str, str]] = None,
        vllm_serve_args: Optional[List[str]] = None,
        vllm_serve_kwargs: Optional[Dict[str, str]] = None,
        chat_template_kwargs: Optional[Dict[str, Any]] = None,
    ):
        """Deploy an LLM on specified GPUs.

        Args:
            model_path: Path to the model to deploy.
            port: Port the vLLM server listens on.
            gpus: A single GPU id or a list of GPU ids to deploy on. With more than
                one GPU the model is served with tensor parallelism.
            tokenizer_path: Path to the tokenizer to use. Defaults to `model_path`.
            max_model_len: Value passed to vLLM as `--max_model_len`.
            max_lora_rank: Max rank of LoRA adapter. Defaults to `None` which disables LoRA adapter.
            host: Host address for vLLM server.
            mem_util: Value passed to vLLM as `--gpu-memory-utilization`.
            deploy_timeout_seconds: Timeout to deploy (in seconds).
            launch_vllm_in_init: Launch vLLM during initialization of this class.
            enforce_eager: Enforce eager mode.
            vllm_log_level: Log level of vLLM server.
            silent_mode: If True, the server's stdout/stderr are discarded.
            env_variable_dict: Environment variables to use for vLLM server, e.g., {'KEY': 'VALUE'}.
            vllm_serve_args: Arguments to pass to vLLM server, e.g., ['--enable-reasoning'].
            vllm_serve_kwargs: Keyword arguments to pass to vLLM server, e.g., {'--reasoning-parser': 'deepseek-r1'}.
            chat_template_kwargs: Default keyword arguments to pass to chat template,
                e.g., {'enable_thinking': False}. Can be overridden per call in `chat_completion`.

        Example:
            # deploy a model on GPU 0 and 1
            llm = VLLMServer(
                model_path='path/to/model',
                tokenizer_path='path/to/tokenizer',
                gpus=[0, 1],  # set gpus=0 or gpus=[0] if you only use one GPU
                port=12001,
                mem_util=0.8
            )
            # query the base model
            llm.chat_completion('hello')

            # load adapter and query it
            llm.load_lora_adapter('adapter_1', '/path/to/adapter')
            llm.chat_completion('hello', lora_name='adapter_1')

            # unload adapter
            llm.unload_lora_adapter('adapter_1')

            # release resources
            llm.close()
        """
        self._model_path = model_path
        self._port = port
        self._gpus = gpus
        self._tokenizer_path = (
            tokenizer_path if tokenizer_path is not None else model_path
        )
        self._max_model_len = max_model_len
        self._max_lora_rank = max_lora_rank
        self._host = host
        self._mem_util = mem_util
        self._deploy_timeout_seconds = deploy_timeout_seconds
        self._enforce_eager = enforce_eager
        self._vllm_log_level = vllm_log_level
        self._silent_mode = silent_mode
        self._env_variable_dict = env_variable_dict
        self._vllm_serve_args = vllm_serve_args
        self._vllm_serve_kwargs = vllm_serve_kwargs
        self._chat_template_kwargs = chat_template_kwargs
        self._vllm_server_process = None
        # Updated by launch_vllm_server(); a detached server is never killed
        # by this object.
        self._detached = False

        # Deploy vLLMs
        if launch_vllm_in_init:
            self.launch_vllm_server()

    def _endpoint(self, path: str) -> str:
        """Return the full HTTP URL of `path` on the managed server."""
        return f"http://{self._host}:{self._port}{path}"

    def launch_vllm_server(self, detach: bool = False, skip_if_running: bool = False):
        """Start the vLLM server subprocess and block until it reports healthy.

        Args:
            detach: Start the server in its own session (non-Windows only) and
                never kill it from this object.
            skip_if_running: Do nothing if a server already answers `/health`
                on the configured host and port.

        Raises:
            ImportError: If the `vllm` package cannot be imported.
            SystemExit: If the server crashes or does not become healthy in time.
        """
        try:
            import vllm  # noqa: F401  (availability check only)
        except ImportError as err:
            raise ImportError(
                "vLLM is required to launch a VLLMServer but could not be imported "
                "in the current Python environment."
            ) from err

        if skip_if_running and self._is_server_running():
            print(
                f"[vLLM] Server already running on http://{self._host}:{self._port}. "
                f"Skipping launch."
            )
            return

        self._detached = detach
        self._vllm_server_process = self._launch_vllm(detach=detach)
        self._wait_for_vllm()

    def _launch_vllm(self, detach: bool = False):
        """Launch a vLLM server and return the subprocess."""
        # Normalize to a list so a single id, a list or any other sequence of
        # ids is treated uniformly below.
        gpu_ids = [self._gpus] if isinstance(self._gpus, int) else list(self._gpus)
        visible_devices = ",".join(str(g) for g in gpu_ids)

        cmd = [
            sys.executable,
            "-m",
            "vllm.entrypoints.openai.api_server",
            "--model",
            self._model_path,
            "--tokenizer",
            self._tokenizer_path,
            "--max_model_len",
            str(self._max_model_len),
            "--host",
            self._host,
            "--port",
            str(self._port),
            "--gpu-memory-utilization",
            str(self._mem_util),
            "--tensor-parallel-size",
            str(len(gpu_ids)),
            "--trust-remote-code",
            "--chat-template-content-format",
            "string",
        ]

        if self._enforce_eager:
            cmd.append("--enforce_eager")

        # Other args for vllm serve
        if self._vllm_serve_args is not None:
            cmd.extend(self._vllm_serve_args)

        # Other kwargs for vllm serve. Values are stringified because Popen
        # rejects non-string arguments.
        if self._vllm_serve_kwargs is not None:
            for flag, value in self._vllm_serve_kwargs.items():
                cmd.extend([flag, str(value)])

        # Environmental variables
        env = os.environ.copy()
        env["CUDA_VISIBLE_DEVICES"] = visible_devices
        env["VLLM_LOGGING_LEVEL"] = self._vllm_log_level

        # FIXME: This code is required on the author's machine,
        # FIXME: possibly due to a bad NCCL configuration.
        if len(gpu_ids) > 1:
            # set NCCL environment variable
            env["NCCL_P2P_DISABLE"] = "1"
            # disable custom all reduce
            cmd.append("--disable-custom-all-reduce")

        # Enable LoRA dynamic loading
        if self._max_lora_rank is not None:
            cmd.extend(
                [
                    "--enable-lora",
                    "--max-lora-rank",
                    str(self._max_lora_rank),
                ]
            )
            env["VLLM_ALLOW_RUNTIME_LORA_UPDATING"] = "True"

        # Other env variables (applied last so they can override the above)
        if self._env_variable_dict is not None:
            for key, value in self._env_variable_dict.items():
                env[key] = str(value)

        _print_cmd_list(cmd, gpus=self._gpus, host=self._host, port=self._port)

        # Launch vllm using subprocess. DEVNULL avoids holding an open file
        # handle in this process for the lifetime of the server.
        stdout = subprocess.DEVNULL if self._silent_mode else None
        proc = subprocess.Popen(
            cmd,
            env=env,
            stdout=stdout,
            stderr=subprocess.STDOUT,
            # Equivalent to preexec_fn=os.setsid, without running Python code
            # between fork and exec.
            start_new_session=detach and sys.platform != "win32",
        )
        return proc

    def _kill_vllm_process(self):
        """Terminate the managed server process and all of its descendants.

        Does nothing (besides printing a notice) when the server was launched
        detached, or when no process is being managed.
        """
        if getattr(self, "_detached", False):
            print(
                f"[vLLM] Server on port {self._port} is detached. Not killing process."
            )
            return

        proc = self._vllm_server_process
        if proc is None:
            return

        # Get child processes before terminating parent
        try:
            children = psutil.Process(proc.pid).children(recursive=True)
        except psutil.NoSuchProcess:
            children = []

        # Terminate parent process; escalate to kill if it does not exit in time
        proc.terminate()
        try:
            proc.wait(timeout=5)
            print(f"[vLLM] terminated process: {proc.pid}")
        except subprocess.TimeoutExpired:
            proc.kill()
            proc.wait()  # reap the killed process
            print(f"[vLLM] killed process: {proc.pid}")

        # Clean up remaining children regardless of how the parent exited
        for child in children:
            try:
                child.terminate()
            except psutil.NoSuchProcess:
                pass
        _, still_alive = psutil.wait_procs(children, timeout=2)
        for child in still_alive:
            try:
                child.kill()
            except psutil.NoSuchProcess:
                pass

    def _is_server_running(self):
        """Check if a vLLM server is already running on the given host and port."""
        try:
            response = requests.get(self._endpoint("/health"), timeout=1)
        except requests.exceptions.RequestException:
            return False
        return response.status_code == 200

    def _wait_for_vllm(self):
        """Poll the server process and `/health` until ready.

        Exits via `sys.exit` if the process dies, or (after trying to kill the
        process) if the server is not healthy within `deploy_timeout_seconds`.
        """
        # Use a wall-clock deadline: each iteration costs the sleep plus the
        # health request, so counting iterations would overshoot the timeout.
        deadline = time.monotonic() + self._deploy_timeout_seconds
        while time.monotonic() < deadline:
            # check process status
            if self._vllm_server_process.poll() is not None:
                sys.exit(
                    f"[vLLM] crashed (exit {self._vllm_server_process.returncode})"
                )

            # check server status
            if self._is_server_running():
                return
            time.sleep(1)

        # Server failed to initialize
        print("[vLLM] failed to start within timeout")
        self._kill_vllm_process()
        sys.exit("[vLLM] failed to start within timeout")

    def unload_lora_adapter(self, lora_name: str):
        """Unload lora adapter given the lora name.

        This is best-effort: connection errors and error responses are ignored.

        Args:
            lora_name: Lora adapter name.
        """
        headers = {"Content-Type": "application/json"}
        payload = {"lora_name": lora_name}
        try:
            requests.post(
                self._endpoint("/v1/unload_lora_adapter"),
                json=payload,
                headers=headers,
                timeout=10,
            )
        except requests.exceptions.RequestException:
            # Deliberately ignored; see docstring.
            pass

    def load_lora_adapter(
        self, lora_name: str, new_adapter_path: str, num_trails: int = 5
    ):
        """Dynamically load a LoRA adapter.

        Any adapter already registered under `lora_name` is unloaded first.

        Args:
            lora_name: LoRA adapter name.
            new_adapter_path: Path to the new LoRA adapter weights.
            num_trails: Maximum number of attempts when the request itself fails
                (e.g. connection error or timeout).

        Returns:
            True if the server accepted the adapter (HTTP 200), False otherwise.

        Raises:
            ValueError: If this instance was created without `max_lora_rank`.
        """
        # Validate before touching the server so a misconfigured instance does
        # not unload anything.
        if self._max_lora_rank is None:
            raise ValueError(
                'LoRA is not enabled for this VLLMServer instance, since "max_lora_rank" is not set.'
            )

        # First unload lora adapter
        self.unload_lora_adapter(lora_name)

        # Prepare the payload for LoRA update
        payload = {"lora_name": lora_name, "lora_path": new_adapter_path}
        headers = {"Content-Type": "application/json"}
        lora_api_url = self._endpoint("/v1/load_lora_adapter")

        # Retry only when the request could not be completed
        for _ in range(num_trails):
            try:
                response = requests.post(
                    lora_api_url, json=payload, headers=headers, timeout=60
                )
            except requests.exceptions.RequestException:
                continue

            if response.status_code == 200:
                print(
                    f"[vLLM] Successfully load LoRA adapter: {lora_name} from {new_adapter_path}"
                )
                return True

            # The server answered with an error: report it as a failure
            # instead of claiming success.
            print(
                f"[vLLM] Failed to load LoRA adapter. "
                f"Status code: {response.status_code}, Response: {response.text}"
            )
            return False

        print("[vLLM] Error loading LoRA adapter.")
        return False

    def close(self):
        """Shut down vLLM server and kill all vLLM processes."""
        if self._vllm_server_process is None:
            return
        self._kill_vllm_process()
        if not getattr(self, "_detached", False):
            # Drop the handle so a repeated close() cannot act on a PID that
            # may have been reused by an unrelated process.
            self._vllm_server_process = None

    def reload(self):
        """Reload the vllm server.

        The currently managed server (if any) is shut down first so the new one
        can bind to the same port. A detached server is not shut down by
        `close()`, so reloading on top of one is not supported.
        """
        self.close()
        self.launch_vllm_server()

    def chat_completion(
        self,
        message: str | List[openai.types.chat.ChatCompletionMessageParam],
        max_tokens: Optional[int] = None,
        timeout_seconds: Optional[int] = None,
        lora_name: Optional[str] = None,
        temperature: float = 0.9,
        top_p: float = 0.9,
        chat_template_kwargs: Optional[Dict[str, Any]] = None,
    ) -> str:
        """Send a chat completion query with OpenAI format to the vLLM server.
        Return the response content.

        Args:
            message: The message in str or openai format.
            max_tokens: The maximum number of tokens to generate.
            timeout_seconds: The timeout seconds.
            lora_name: Lora adapter name. Defaults to None which uses base model.
            temperature: The temperature parameter.
            top_p: The top p parameter.
            chat_template_kwargs: The chat template kwargs, e.g., {'enable_thinking': False}.
                Takes precedence over the instance-level default given at construction.

        Raises:
            requests.HTTPError: If the server answers with an error status.
        """
        # A plain string becomes a single user turn; a list is already in
        # OpenAI message format and must be sent as-is (not nested in a list).
        if isinstance(message, str):
            messages = [{"role": "user", "content": message.strip()}]
        else:
            messages = list(message)

        data: Dict[str, Any] = {
            "messages": messages,
            "temperature": temperature,
            "top_p": top_p,
            "max_tokens": max_tokens,
        }

        # Use the specified lora adapter
        if lora_name is not None:
            data["model"] = lora_name

        # Chat template keyword args: per-call value wins over instance default
        template_kwargs = (
            chat_template_kwargs
            if chat_template_kwargs is not None
            else self._chat_template_kwargs
        )
        if template_kwargs is not None:
            data["chat_template_kwargs"] = template_kwargs

        # Request
        headers = {"Content-Type": "application/json"}
        response = requests.post(
            self._endpoint("/v1/chat/completions"),
            headers=headers,
            json=data,
            timeout=timeout_seconds,
        )
        # Surface server errors explicitly (as `embedding` does) rather than
        # failing later with a KeyError on the error payload.
        response.raise_for_status()
        return response.json()["choices"][0]["message"]["content"]

    def embedding(
        self,
        text: str | List[str],
        dimensions: Optional[int] = None,
        timeout_seconds: Optional[float] = None,
        lora_name: Optional[str] = None,
        **kwargs,
    ) -> List[float] | List[List[float]]:
        """Generate embeddings for the given text(s).

        Args:
            text: The text or a list of texts to embed.
            dimensions: The number of dimensions for the output embeddings.
            timeout_seconds: The timeout seconds.
            lora_name: Lora adapter name. Defaults to None which uses base model.
            **kwargs: Extra fields merged into the request body.

        Returns:
            The embedding for the text, or a list of embeddings for the list of texts.

        Raises:
            requests.HTTPError: If the server answers with an error status.
        """
        is_str_input = isinstance(text, str)
        inputs = [text] if is_str_input else text

        # Prepare arguments for the API call
        data: Dict[str, Any] = {"input": inputs}
        if dimensions is not None:
            data["dimensions"] = dimensions
        if lora_name is not None:
            data["model"] = lora_name
        data.update(kwargs)

        headers = {"Content-Type": "application/json"}
        response = requests.post(
            self._endpoint("/v1/embeddings"),
            headers=headers,
            json=data,
            timeout=timeout_seconds,
        )
        response.raise_for_status()

        embeddings = [item["embedding"] for item in response.json()["data"]]

        # Mirror the input shape: single text -> single embedding
        if is_str_input:
            return embeddings[0]
        return embeddings
|