speedy-utils 1.1.27__py3-none-any.whl → 1.1.28__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- llm_utils/__init__.py +16 -4
- llm_utils/chat_format/__init__.py +10 -10
- llm_utils/chat_format/display.py +33 -21
- llm_utils/chat_format/transform.py +17 -19
- llm_utils/chat_format/utils.py +6 -4
- llm_utils/group_messages.py +17 -14
- llm_utils/lm/__init__.py +6 -5
- llm_utils/lm/async_lm/__init__.py +1 -0
- llm_utils/lm/async_lm/_utils.py +10 -9
- llm_utils/lm/async_lm/async_llm_task.py +141 -137
- llm_utils/lm/async_lm/async_lm.py +48 -42
- llm_utils/lm/async_lm/async_lm_base.py +59 -60
- llm_utils/lm/async_lm/lm_specific.py +4 -3
- llm_utils/lm/base_prompt_builder.py +93 -70
- llm_utils/lm/llm.py +126 -108
- llm_utils/lm/llm_signature.py +4 -2
- llm_utils/lm/lm_base.py +72 -73
- llm_utils/lm/mixins.py +102 -62
- llm_utils/lm/openai_memoize.py +124 -87
- llm_utils/lm/signature.py +105 -92
- llm_utils/lm/utils.py +42 -23
- llm_utils/scripts/vllm_load_balancer.py +23 -30
- llm_utils/scripts/vllm_serve.py +8 -7
- llm_utils/vector_cache/__init__.py +9 -3
- llm_utils/vector_cache/cli.py +1 -1
- llm_utils/vector_cache/core.py +59 -63
- llm_utils/vector_cache/types.py +7 -5
- llm_utils/vector_cache/utils.py +12 -8
- speedy_utils/__imports.py +244 -0
- speedy_utils/__init__.py +90 -194
- speedy_utils/all.py +125 -227
- speedy_utils/common/clock.py +37 -42
- speedy_utils/common/function_decorator.py +6 -12
- speedy_utils/common/logger.py +43 -52
- speedy_utils/common/notebook_utils.py +13 -21
- speedy_utils/common/patcher.py +21 -17
- speedy_utils/common/report_manager.py +42 -44
- speedy_utils/common/utils_cache.py +152 -169
- speedy_utils/common/utils_io.py +137 -103
- speedy_utils/common/utils_misc.py +15 -21
- speedy_utils/common/utils_print.py +22 -28
- speedy_utils/multi_worker/process.py +66 -79
- speedy_utils/multi_worker/thread.py +78 -155
- speedy_utils/scripts/mpython.py +38 -36
- speedy_utils/scripts/openapi_client_codegen.py +10 -10
- {speedy_utils-1.1.27.dist-info → speedy_utils-1.1.28.dist-info}/METADATA +1 -1
- speedy_utils-1.1.28.dist-info/RECORD +57 -0
- vision_utils/README.md +202 -0
- vision_utils/__init__.py +5 -0
- vision_utils/io_utils.py +470 -0
- vision_utils/plot.py +345 -0
- speedy_utils-1.1.27.dist-info/RECORD +0 -52
- {speedy_utils-1.1.27.dist-info → speedy_utils-1.1.28.dist-info}/WHEEL +0 -0
- {speedy_utils-1.1.27.dist-info → speedy_utils-1.1.28.dist-info}/entry_points.txt +0 -0
llm_utils/lm/utils.py
CHANGED
@@ -1,17 +1,17 @@
 import os
-import signal
-import time
-from typing import Any, List, Optional, cast
-
-from loguru import logger
-
 
 # Additional imports for VLLM utilities
 import re
+import signal
 import subprocess
+import time
+from typing import Any, List, Optional, cast
+
 import requests
+from loguru import logger
 from openai import OpenAI
 
+
 try:
     import psutil
 
@@ -19,10 +19,12 @@ try:
 except ImportError:
     HAS_PSUTIL = False
     psutil = cast(Any, None)
-    logger.warning(
+    logger.warning(
+        "psutil not available. Some VLLM process management features may be limited."
+    )
 
 # Global tracking of VLLM processes
-_VLLM_PROCESSES:
+_VLLM_PROCESSES: list[subprocess.Popen] = []
 
 
 def _extract_port_from_vllm_cmd(vllm_cmd: str) -> int:
@@ -97,7 +99,12 @@ def _start_vllm_server(vllm_cmd: str, timeout: int = 120) -> subprocess.Popen:
 
     with open(f"/tmp/vllm_{port}.txt", "a") as log_file:
         process = subprocess.Popen(
-            cleaned_cmd.split(),
+            cleaned_cmd.split(),
+            stdout=log_file,
+            stderr=subprocess.STDOUT,
+            text=True,
+            preexec_fn=os.setsid,
+            env=env,
         )
 
     _VLLM_PROCESSES.append(process)
@@ -147,7 +154,9 @@ def _kill_vllm_on_port(port: int) -> bool:
            proc = psutil.Process(process.pid)
            cmdline = " ".join(proc.cmdline())
            if f"--port {port}" in cmdline or f"--port={port}" in cmdline:
-                logger.info(
+                logger.info(
+                    f"Killing tracked VLLM process {process.pid} on port {port}"
+                )
                os.killpg(os.getpgid(process.pid), signal.SIGTERM)
                try:
                    process.wait(timeout=5)
@@ -187,8 +196,12 @@ def _kill_vllm_on_port(port: int) -> bool:
        for proc in psutil.process_iter(["pid", "cmdline"]):
            try:
                cmdline = " ".join(proc.info["cmdline"] or [])
-                if "vllm" in cmdline.lower() and (
-
+                if "vllm" in cmdline.lower() and (
+                    f"--port {port}" in cmdline or f"--port={port}" in cmdline
+                ):
+                    logger.info(
+                        f"Killing untracked VLLM process {proc.info['pid']} on port {port}"
+                    )
                    proc.terminate()
                    try:
                        proc.wait(timeout=5)
@@ -255,7 +268,9 @@ def _is_server_running(port: int) -> bool:
         return False
 
 
-def get_base_client(
+def get_base_client(
+    client=None, cache: bool = True, api_key="abc", vllm_cmd=None, vllm_process=None
+) -> OpenAI:
     """Get OpenAI client from various inputs."""
     from llm_utils import MOpenAI
 
@@ -264,17 +279,21 @@ def get_base_client(client=None, cache: bool = True, api_key="abc", vllm_cmd=Non
            # Parse environment variables from command to get clean command for port extraction
            _, cleaned_cmd = _parse_env_vars_from_cmd(vllm_cmd)
            port = _extract_port_from_vllm_cmd(cleaned_cmd)
-            return MOpenAI(
-
-
-
-
-
+            return MOpenAI(
+                base_url=f"http://localhost:{port}/v1", api_key=api_key, cache=cache
+            )
+        raise ValueError("Either client or vllm_cmd must be provided.")
+    if isinstance(client, int):
+        return MOpenAI(
+            base_url=f"http://localhost:{client}/v1", api_key=api_key, cache=cache
+        )
+    if isinstance(client, str):
         return MOpenAI(base_url=client, api_key=api_key, cache=cache)
-
+    if isinstance(client, OpenAI):
         return MOpenAI(base_url=client.base_url, api_key=api_key, cache=cache)
-
-
+    raise ValueError(
+        "Invalid client type. Must be OpenAI, port (int), base_url (str), or None."
+    )
 
 
 def _is_lora_path(path: str) -> bool:
@@ -285,7 +304,7 @@ def _is_lora_path(path: str) -> bool:
     return os.path.isfile(adapter_config_path)
 
 
-def _get_port_from_client(client: OpenAI) ->
+def _get_port_from_client(client: OpenAI) -> int | None:
     """Extract port from OpenAI client base_url."""
     if hasattr(client, "base_url") and client.base_url:
         base_url = str(client.base_url)
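The reworked get_base_client above dispatches on the type of its first argument: an int is treated as a local port, a str as a base URL, an existing OpenAI client is re-wrapped as a caching MOpenAI, and None requires a vllm_cmd. A minimal usage sketch under that reading (the port, URL, and API key below are illustrative, not defaults shipped by the package):

from openai import OpenAI

from llm_utils.lm.utils import get_base_client

# Assumed: an OpenAI-compatible (e.g. vLLM) server already listening on localhost:8000.
client_from_port = get_base_client(8000)
client_from_url = get_base_client("http://localhost:8000/v1")
client_rewrapped = get_base_client(
    OpenAI(base_url="http://localhost:8000/v1", api_key="abc")
)
# Any other client type raises ValueError, as does client=None without a vllm_cmd.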
llm_utils/scripts/vllm_load_balancer.py
CHANGED
@@ -10,9 +10,11 @@ from datetime import datetime
 
 import aiohttp
 from loguru import logger
-from speedy_utils import setup_logger
 from tabulate import tabulate
 
+from speedy_utils import setup_logger
+
+
 setup_logger(min_interval=5)
 
 
@@ -132,10 +134,9 @@ def format_uptime(start_time):
 
     if hours > 0:
         return f"{hours}h {minutes}m {seconds}s"
-
+    if minutes > 0:
         return f"{minutes}m {seconds}s"
-
-        return f"{seconds}s"
+    return f"{seconds}s"
 
 
 def print_banner():
@@ -247,13 +248,12 @@ async def check_server_health(session, host, port):
                )
                await response.release()
                return True
-
-
-
-
-
-
-    except asyncio.TimeoutError:
+            logger.debug(
+                f"[{LOAD_BALANCER_PORT=}] Health check failed for {url} (Status: {response.status})"
+            )
+            await response.release()
+            return False
+    except TimeoutError:
        logger.debug(f"Health check HTTP request timeout for {url}")
        return False
    except aiohttp.ClientConnectorError as e:
@@ -311,11 +311,11 @@ async def scan_and_update_servers():
 
     if added:
         logger.info(
-            f"Servers added (passed /health check): {sorted(
+            f"Servers added (passed /health check): {sorted(added)}"
         )
     if removed:
         logger.info(
-            f"Servers removed (failed /health check or stopped): {sorted(
+            f"Servers removed (failed /health check or stopped): {sorted(removed)}"
         )
         for server in removed:
             if server in connection_counts:
@@ -329,7 +329,7 @@ async def scan_and_update_servers():
                    f"Removed throttling timestamp for unavailable server {server}"
                )
 
-    available_servers = sorted(
+    available_servers = sorted(current_set)
     for server in available_servers:
         if server not in connection_counts:
             connection_counts[server] = 0
@@ -375,7 +375,9 @@ async def handle_client(client_reader, client_writer):
 
     min_connections = float("inf")
     least_used_available_servers = []
-    for
+    for (
+        server
+    ) in (
         available_servers
     ):  # Iterate only over servers that passed health check
         count = connection_counts.get(server, 0)
@@ -705,9 +707,9 @@ async def stats_json(request):
            {
                "host": BACKEND_HOST,
                "port": port,
-                "active_connections":
-
-
+                "active_connections": (
+                    connection_counts.get(server, 0) if is_online else 0
+                ),
                "status": "ONLINE" if is_online else "OFFLINE",
            }
        )
@@ -929,23 +931,14 @@ async def main():
        logger.info("Cancelling background tasks...")
        scan_task.cancel()
        status_task.cancel()
-
+        with contextlib.suppress(asyncio.CancelledError):
            await asyncio.gather(scan_task, status_task, return_exceptions=True)
-        except asyncio.CancelledError:
-            pass
        print(f"{Colors.BRIGHT_GREEN}✅ Shutdown complete. Goodbye!{Colors.RESET}")
        logger.info("Background tasks finished.")
 
 
 def run_load_balancer():
-    global
-        LOAD_BALANCER_PORT, \
-        BACKEND_PORTS, \
-        BACKEND_HOST, \
-        STATUS_PRINT_INTERVAL, \
-        HEALTH_CHECK_TIMEOUT, \
-        THROTTLE_MS, \
-        STATS_PORT
+    global LOAD_BALANCER_PORT, BACKEND_PORTS, BACKEND_HOST, STATUS_PRINT_INTERVAL, HEALTH_CHECK_TIMEOUT, THROTTLE_MS, STATS_PORT
     args = parse_args()
     LOAD_BALANCER_PORT = args.port
     BACKEND_HOST = args.host
@@ -976,4 +969,4 @@ def run_load_balancer():
 
 
 if __name__ == "__main__":
-    run_load_balancer()
+    run_load_balancer()
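For orientation, the handle_client hunk above belongs to a least-connections selection step: among the backends that passed the /health scan, pick one with the fewest active proxied connections. A standalone sketch of that pattern (the function name and the random tie-break are illustrative, not taken from the script):

import random


def pick_least_loaded(
    available_servers: list[str], connection_counts: dict[str, int]
) -> str | None:
    # Iterate only over servers that passed the health check.
    if not available_servers:
        return None
    min_connections = float("inf")
    least_used: list[str] = []
    for server in available_servers:
        count = connection_counts.get(server, 0)
        if count < min_connections:
            min_connections = count
            least_used = [server]
        elif count == min_connections:
            least_used.append(server)
    # Break ties arbitrarily among equally loaded backends.
    return random.choice(least_used)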
llm_utils/scripts/vllm_serve.py
CHANGED
@@ -75,6 +75,7 @@ from loguru import logger
 from llm_utils.lm.openai_memoize import MOpenAI
 from speedy_utils.common.utils_io import load_by_ext
 
+
 LORA_DIR: str = os.environ.get("LORA_DIR", "/loras")
 LORA_DIR = os.path.abspath(LORA_DIR)
 HF_HOME: str = os.environ.get("HF_HOME", os.path.expanduser("~/.cache/huggingface"))
@@ -93,8 +94,8 @@ def add_lora(
     lora_name_or_path: str,
     host_port: str,
     url: str = "http://HOST:PORT/v1/load_lora_adapter",
-    served_model_name:
-    lora_module:
+    served_model_name: str | None = None,
+    lora_module: str | None = None,
 ) -> dict:
     """Add a LoRA adapter to a running vLLM server."""
     url = url.replace("HOST:PORT", host_port)
@@ -126,7 +127,7 @@ def add_lora(
         return {"error": f"Request failed: {str(e)}"}
 
 
-def unload_lora(lora_name: str, host_port: str) ->
+def unload_lora(lora_name: str, host_port: str) -> dict | None:
     """Unload a LoRA adapter from a running vLLM server."""
     try:
         url = f"http://{host_port}/v1/unload_lora_adapter"
@@ -144,7 +145,7 @@ def unload_lora(lora_name: str, host_port: str) -> Optional[dict]:
 def serve(args) -> None:
     """Start vLLM containers with dynamic args."""
     print("Starting vLLM containers...,")
-    gpu_groups_arr:
+    gpu_groups_arr: list[str] = args.gpu_groups.split(",")
     vllm_binary: str = get_vllm()
     if args.enable_lora:
         vllm_binary = "VLLM_ALLOW_RUNTIME_LORA_UPDATING=True " + vllm_binary
@@ -232,9 +233,9 @@ def get_vllm() -> str:
     vllm_binary = subprocess.check_output("which vllm", shell=True, text=True).strip()
     vllm_binary = os.getenv("VLLM_BINARY", vllm_binary)
     logger.info(f"vLLM binary: {vllm_binary}")
-    assert os.path.exists(
-
-    )
+    assert os.path.exists(
+        vllm_binary
+    ), f"vLLM binary not found at {vllm_binary}, please set VLLM_BINARY env variable"
     return vllm_binary
 
 
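The add_lora/unload_lora signatures above take a LoRA adapter name or path plus the host:port of a vLLM server that was started with runtime LoRA updating enabled. A hedged usage sketch (the adapter path, served name, and port are hypothetical placeholders):

from llm_utils.scripts.vllm_serve import add_lora, unload_lora

# Hypothetical adapter directory and server address; both must already exist.
result = add_lora(
    lora_name_or_path="/loras/my-adapter",
    host_port="localhost:8000",
    served_model_name="my-adapter",
)
print(result)  # dict response, or {"error": ...} if the request failed

# Later, detach the adapter again.
unload_lora("my-adapter", "localhost:8000")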
llm_utils/vector_cache/__init__.py
CHANGED
@@ -11,15 +11,21 @@ Example:
     # Using local model
     cache = VectorCache("Qwen/Qwen3-Embedding-0.6B")
     embeddings = cache.embeds(["Hello world", "How are you?"])
-
+
     # Using OpenAI API
     cache = VectorCache("https://api.openai.com/v1")
     embeddings = cache.embeds(["Hello world", "How are you?"])
 """
 
 from .core import VectorCache
-from .utils import get_default_cache_path, validate_model_name
+from .utils import estimate_cache_size, get_default_cache_path, validate_model_name
+
 
 __version__ = "0.1.0"
 __author__ = "AnhVTH <anhvth.226@gmail.com>"
-__all__ = [
+__all__ = [
+    "VectorCache",
+    "get_default_cache_path",
+    "validate_model_name",
+    "estimate_cache_size",
+]
llm_utils/vector_cache/cli.py
CHANGED
@@ -106,7 +106,7 @@ def handle_embed(args):
         if not file_path.exists():
             raise FileNotFoundError(f"File not found: {args.file}")
 
-        with open(file_path,
+        with open(file_path, encoding="utf-8") as f:
             texts.extend([line.strip() for line in f if line.strip()])
 
     if not texts:
llm_utils/vector_cache/core.py
CHANGED
@@ -61,18 +61,18 @@ class VectorCache:
     def __init__(
         self,
         url_or_model: str,
-        backend:
-        embed_size:
-        db_path:
+        backend: Literal["vllm", "transformers", "openai"] | None = None,
+        embed_size: int | None = None,
+        db_path: str | None = None,
         # OpenAI API parameters
-        api_key:
-        model_name:
+        api_key: str | None = "abc",
+        model_name: str | None = None,
         # vLLM parameters
         vllm_gpu_memory_utilization: float = 0.5,
         vllm_tensor_parallel_size: int = 1,
         vllm_dtype: str = "auto",
         vllm_trust_remote_code: bool = False,
-        vllm_max_model_len:
+        vllm_max_model_len: int | None = None,
         # Transformers parameters
         transformers_device: str = "auto",
         transformers_batch_size: int = 32,
@@ -149,7 +149,6 @@ class VectorCache:
            if self.verbose:
                print(f"Model auto-detection failed: {e}, using default model")
            # Fallback to default if auto-detection fails
-            pass
 
        # Set default db_path if not provided
        if db_path is None:
@@ -185,7 +184,7 @@ class VectorCache:
            print(f"✓ {self.backend.upper()} model/client loaded successfully")
 
    def _determine_backend(
-        self, backend:
+        self, backend: Literal["vllm", "transformers", "openai"] | None
    ) -> str:
        """Determine the appropriate backend based on url_or_model and user preference."""
        if backend is not None:
@@ -202,7 +201,7 @@ class VectorCache:
        # Default to vllm for local models
        return "vllm"
 
-    def _try_infer_model_name(self, model_name:
+    def _try_infer_model_name(self, model_name: str | None) -> str | None:
        """Infer model name for OpenAI backend if not explicitly provided."""
        if model_name:
            return model_name
@@ -243,17 +242,21 @@ class VectorCache:
        )  # Checkpoint WAL every 1000 pages
 
    def _ensure_schema(self) -> None:
-        self.conn.execute(
+        self.conn.execute(
+            """
        CREATE TABLE IF NOT EXISTS cache (
            hash TEXT PRIMARY KEY,
            text TEXT,
            embedding BLOB
        )
-        """
+        """
+        )
        # Add index for faster lookups if it doesn't exist
-        self.conn.execute(
+        self.conn.execute(
+            """
        CREATE INDEX IF NOT EXISTS idx_cache_hash ON cache(hash)
-        """
+        """
+        )
        self.conn.commit()
 
    def _load_openai_client(self) -> None:
@@ -275,7 +278,7 @@ class VectorCache:
        tensor_parallel_size = cast(int, self.config["vllm_tensor_parallel_size"])
        dtype = cast(str, self.config["vllm_dtype"])
        trust_remote_code = cast(bool, self.config["vllm_trust_remote_code"])
-        max_model_len = cast(
+        max_model_len = cast(int | None, self.config["vllm_max_model_len"])
 
        vllm_kwargs = {
            "model": self.url_or_model,
@@ -312,8 +315,7 @@ class VectorCache:
                f"4. Ensure no other processes are using GPU memory during initialization\n"
                f"Original error: {e}"
            ) from e
-
-            raise
+            raise
        elif self.backend == "transformers":
            import torch  # type: ignore[import-not-found]  # noqa: F401
            from transformers import (  # type: ignore[import-not-found]
@@ -345,29 +347,28 @@ class VectorCache:
    def _get_embeddings(self, texts: list[str]) -> list[list[float]]:
        """Get embeddings using the configured backend."""
        assert isinstance(texts, list), "texts must be a list"
-        assert all(
-
-        )
+        assert all(
+            isinstance(t, str) for t in texts
+        ), "all elements in texts must be strings"
        if self.backend == "openai":
            return self._get_openai_embeddings(texts)
-
+        if self.backend == "vllm":
            return self._get_vllm_embeddings(texts)
-
+        if self.backend == "transformers":
            return self._get_transformers_embeddings(texts)
-
-        raise ValueError(f"Unsupported backend: {self.backend}")
+        raise ValueError(f"Unsupported backend: {self.backend}")
 
    def _get_openai_embeddings(self, texts: list[str]) -> list[list[float]]:
        """Get embeddings using OpenAI API."""
        assert isinstance(texts, list), "texts must be a list"
-        assert all(
-
-        )
+        assert all(
+            isinstance(t, str) for t in texts
+        ), "all elements in texts must be strings"
        # Assert valid model_name for OpenAI backend
        model_name = self.config["model_name"]
-        assert
-
-        )
+        assert (
+            model_name is not None and model_name.strip()
+        ), f"Invalid model_name for OpenAI backend: {model_name}. Model name must be provided and non-empty."
 
        if self._client is None:
            if self.verbose:
@@ -385,9 +386,9 @@ class VectorCache:
    def _get_vllm_embeddings(self, texts: list[str]) -> list[list[float]]:
        """Get embeddings using vLLM."""
        assert isinstance(texts, list), "texts must be a list"
-        assert all(
-
-        )
+        assert all(
+            isinstance(t, str) for t in texts
+        ), "all elements in texts must be strings"
        if self._model is None:
            if self.verbose:
                print("🔧 Loading vLLM model...")
@@ -402,9 +403,9 @@ class VectorCache:
    def _get_transformers_embeddings(self, texts: list[str]) -> list[list[float]]:
        """Get embeddings using transformers directly."""
        assert isinstance(texts, list), "texts must be a list"
-        assert all(
-
-        )
+        assert all(
+            isinstance(t, str) for t in texts
+        ), "all elements in texts must be strings"
        if self._model is None:
            if self.verbose:
                print("🔧 Loading Transformers model...")
@@ -464,13 +465,12 @@ class VectorCache:
        left_padding = attention_mask[:, -1].sum() == attention_mask.shape[0]
        if left_padding:
            return last_hidden_states[:, -1]
-
-
-
-
-
-
-        ]
+        sequence_lengths = attention_mask.sum(dim=1) - 1
+        batch_size = last_hidden_states.shape[0]
+        return last_hidden_states[
+            torch.arange(batch_size, device=last_hidden_states.device),
+            sequence_lengths,
+        ]
 
    def _hash_text(self, text: str) -> str:
        return hashlib.sha1(text.encode("utf-8")).hexdigest()
@@ -486,8 +486,7 @@ class VectorCache:
        try:
            if params is None:
                return self.conn.execute(query)
-
-                return self.conn.execute(query, params)
+            return self.conn.execute(query, params)
 
        except sqlite3.OperationalError as e:
            last_exception = e
@@ -502,9 +501,8 @@ class VectorCache:
 
                time.sleep(delay)
                continue
-
-
-            raise
+            # Re-raise if not a lock error or max retries exceeded
+            raise
        except Exception:
            # Re-raise any other exceptions
            raise
@@ -524,9 +522,9 @@ class VectorCache:
            computing missing embeddings.
        """
        assert isinstance(texts, list), "texts must be a list"
-        assert all(
-
-        )
+        assert all(
+            isinstance(t, str) for t in texts
+        ), "all elements in texts must be strings"
        if not texts:
            return np.empty((0, 0), dtype=np.float32)
        t = time()
@@ -554,11 +552,11 @@ class VectorCache:
        # Determine which texts are missing
        if cache:
            missing_items: list[tuple[str, str]] = [
-                (t, h) for t, h in zip(texts, hashes) if h not in hit_map
+                (t, h) for t, h in zip(texts, hashes, strict=False) if h not in hit_map
            ]
        else:
            missing_items: list[tuple[str, str]] = [
-                (t, h) for t, h in zip(texts, hashes)
+                (t, h) for t, h in zip(texts, hashes, strict=False)
            ]
 
        if missing_items:
@@ -608,7 +606,7 @@ class VectorCache:
 
            # Prepare batch data for immediate insert
            batch_data: list[tuple[str, str, bytes]] = []
-            for (text, h), vec in zip(batch_items, batch_embeds):
+            for (text, h), vec in zip(batch_items, batch_embeds, strict=False):
                arr = np.asarray(vec, dtype=np.float32)
                batch_data.append((h, text, arr.tobytes()))
                hit_map[h] = arr
@@ -640,9 +638,9 @@ class VectorCache:
 
    def __call__(self, texts: list[str], cache: bool = True) -> np.ndarray:
        assert isinstance(texts, list), "texts must be a list"
-        assert all(
-
-        )
+        assert all(
+            isinstance(t, str) for t in texts
+        ), "all elements in texts must be strings"
        return self.embeds(texts, cache)
 
    def _bulk_insert(self, data: list[tuple[str, str, bytes]]) -> None:
@@ -662,7 +660,7 @@ class VectorCache:
 
        for attempt in range(max_retries + 1):
            try:
-
+                self.conn.executemany(
                    "INSERT OR IGNORE INTO cache (hash, text, embedding) VALUES (?, ?, ?)",
                    data,
                )
@@ -688,9 +686,8 @@ class VectorCache:
 
                time.sleep(delay)
                continue
-
-
-            raise
+            # Re-raise if not a lock error or max retries exceeded
+            raise
        except Exception:
            # Re-raise any other exceptions
            raise
@@ -723,12 +720,11 @@ class VectorCache:
 
                time.sleep(delay)
                continue
-
-            raise
+            raise
        except Exception:
            raise
 
-    def get_config(self) ->
+    def get_config(self) -> dict[str, Any]:
        """Get current configuration."""
        return {
            "url_or_model": self.url_or_model,
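The core.py changes above leave VectorCache's storage format as-is: each text is keyed by its SHA-1 hex digest, its embedding is stored as raw float32 bytes, and INSERT OR IGNORE keeps re-inserts idempotent. A self-contained sketch of that round trip (the in-memory database and random vector are placeholders for illustration, not part of the package):

import hashlib
import sqlite3

import numpy as np

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE IF NOT EXISTS cache (hash TEXT PRIMARY KEY, text TEXT, embedding BLOB)"
)

text = "Hello world"
key = hashlib.sha1(text.encode("utf-8")).hexdigest()  # same keying as _hash_text
vec = np.random.rand(8).astype(np.float32)  # stand-in for a real embedding

conn.execute(
    "INSERT OR IGNORE INTO cache (hash, text, embedding) VALUES (?, ?, ?)",
    (key, text, vec.tobytes()),
)

row = conn.execute("SELECT embedding FROM cache WHERE hash = ?", (key,)).fetchone()
restored = np.frombuffer(row[0], dtype=np.float32)  # decode exactly as stored
assert np.array_equal(vec, restored)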
llm_utils/vector_cache/types.py
CHANGED
@@ -1,15 +1,17 @@
 """Type definitions for the embed_cache package."""
 
-from typing import
+from typing import Any, Dict, List, Optional, Tuple, Union
+
 import numpy as np
 from numpy.typing import NDArray
 
+
 # Type aliases
-TextList =
+TextList = list[str]
 EmbeddingArray = NDArray[np.float32]
-EmbeddingList =
-CacheStats =
+EmbeddingList = list[list[float]]
+CacheStats = dict[str, int]
 ModelIdentifier = str  # Either URL or model name/path
 
 # For backwards compatibility
-Embeddings = Union[EmbeddingArray, EmbeddingList]
+Embeddings = Union[EmbeddingArray, EmbeddingList]