speedy-utils 1.1.12__tar.gz → 1.1.14__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {speedy_utils-1.1.12 → speedy_utils-1.1.14}/PKG-INFO +2 -1
- {speedy_utils-1.1.12 → speedy_utils-1.1.14}/pyproject.toml +2 -1
- {speedy_utils-1.1.12 → speedy_utils-1.1.14}/src/llm_utils/__init__.py +4 -1
- {speedy_utils-1.1.12 → speedy_utils-1.1.14}/src/llm_utils/lm/async_lm/async_lm.py +2 -1
- {speedy_utils-1.1.12 → speedy_utils-1.1.14}/src/llm_utils/lm/async_lm/async_lm_base.py +6 -6
- speedy_utils-1.1.14/src/llm_utils/vector_cache/__init__.py +25 -0
- speedy_utils-1.1.14/src/llm_utils/vector_cache/cli.py +200 -0
- speedy_utils-1.1.14/src/llm_utils/vector_cache/core.py +557 -0
- speedy_utils-1.1.14/src/llm_utils/vector_cache/types.py +15 -0
- speedy_utils-1.1.14/src/llm_utils/vector_cache/utils.py +42 -0
- {speedy_utils-1.1.12 → speedy_utils-1.1.14}/src/speedy_utils/__init__.py +1 -1
- {speedy_utils-1.1.12 → speedy_utils-1.1.14}/src/speedy_utils/common/logger.py +2 -2
- {speedy_utils-1.1.12 → speedy_utils-1.1.14}/src/speedy_utils/common/utils_io.py +2 -2
- {speedy_utils-1.1.12 → speedy_utils-1.1.14}/src/speedy_utils/common/utils_print.py +5 -5
- {speedy_utils-1.1.12 → speedy_utils-1.1.14}/src/speedy_utils/multi_worker/process.py +4 -4
- {speedy_utils-1.1.12 → speedy_utils-1.1.14}/src/speedy_utils/multi_worker/thread.py +4 -4
- {speedy_utils-1.1.12 → speedy_utils-1.1.14}/README.md +0 -0
- {speedy_utils-1.1.12 → speedy_utils-1.1.14}/src/llm_utils/chat_format/__init__.py +0 -0
- {speedy_utils-1.1.12 → speedy_utils-1.1.14}/src/llm_utils/chat_format/display.py +0 -0
- {speedy_utils-1.1.12 → speedy_utils-1.1.14}/src/llm_utils/chat_format/transform.py +0 -0
- {speedy_utils-1.1.12 → speedy_utils-1.1.14}/src/llm_utils/chat_format/utils.py +0 -0
- {speedy_utils-1.1.12 → speedy_utils-1.1.14}/src/llm_utils/group_messages.py +0 -0
- {speedy_utils-1.1.12 → speedy_utils-1.1.14}/src/llm_utils/lm/__init__.py +0 -0
- {speedy_utils-1.1.12 → speedy_utils-1.1.14}/src/llm_utils/lm/async_lm/__init__.py +0 -0
- {speedy_utils-1.1.12 → speedy_utils-1.1.14}/src/llm_utils/lm/async_lm/_utils.py +0 -0
- {speedy_utils-1.1.12 → speedy_utils-1.1.14}/src/llm_utils/lm/async_lm/async_llm_task.py +0 -0
- {speedy_utils-1.1.12 → speedy_utils-1.1.14}/src/llm_utils/lm/async_lm/lm_specific.py +0 -0
- {speedy_utils-1.1.12 → speedy_utils-1.1.14}/src/llm_utils/lm/openai_memoize.py +0 -0
- {speedy_utils-1.1.12 → speedy_utils-1.1.14}/src/llm_utils/lm/utils.py +0 -0
- {speedy_utils-1.1.12 → speedy_utils-1.1.14}/src/llm_utils/scripts/README.md +0 -0
- {speedy_utils-1.1.12 → speedy_utils-1.1.14}/src/llm_utils/scripts/vllm_load_balancer.py +0 -0
- {speedy_utils-1.1.12 → speedy_utils-1.1.14}/src/llm_utils/scripts/vllm_serve.py +0 -0
- {speedy_utils-1.1.12 → speedy_utils-1.1.14}/src/speedy_utils/all.py +0 -0
- {speedy_utils-1.1.12 → speedy_utils-1.1.14}/src/speedy_utils/common/__init__.py +0 -0
- {speedy_utils-1.1.12 → speedy_utils-1.1.14}/src/speedy_utils/common/clock.py +0 -0
- {speedy_utils-1.1.12 → speedy_utils-1.1.14}/src/speedy_utils/common/function_decorator.py +0 -0
- {speedy_utils-1.1.12 → speedy_utils-1.1.14}/src/speedy_utils/common/notebook_utils.py +0 -0
- {speedy_utils-1.1.12 → speedy_utils-1.1.14}/src/speedy_utils/common/report_manager.py +0 -0
- {speedy_utils-1.1.12 → speedy_utils-1.1.14}/src/speedy_utils/common/utils_cache.py +0 -0
- {speedy_utils-1.1.12 → speedy_utils-1.1.14}/src/speedy_utils/common/utils_misc.py +0 -0
- {speedy_utils-1.1.12 → speedy_utils-1.1.14}/src/speedy_utils/multi_worker/__init__.py +0 -0
- {speedy_utils-1.1.12 → speedy_utils-1.1.14}/src/speedy_utils/scripts/__init__.py +0 -0
- {speedy_utils-1.1.12 → speedy_utils-1.1.14}/src/speedy_utils/scripts/mpython.py +0 -0
- {speedy_utils-1.1.12 → speedy_utils-1.1.14}/src/speedy_utils/scripts/openapi_client_codegen.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.3
|
|
2
2
|
Name: speedy-utils
|
|
3
|
-
Version: 1.1.12
|
|
3
|
+
Version: 1.1.14
|
|
4
4
|
Summary: Fast and easy-to-use package for data science
|
|
5
5
|
Author: AnhVTH
|
|
6
6
|
Author-email: anhvth.226@gmail.com
|
|
@@ -25,6 +25,7 @@ Requires-Dist: jupyterlab
|
|
|
25
25
|
Requires-Dist: loguru
|
|
26
26
|
Requires-Dist: matplotlib
|
|
27
27
|
Requires-Dist: numpy
|
|
28
|
+
Requires-Dist: openai (>=1.106.0,<2.0.0)
|
|
28
29
|
Requires-Dist: packaging (>=23.2,<25)
|
|
29
30
|
Requires-Dist: pandas
|
|
30
31
|
Requires-Dist: pydantic
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
[tool.poetry]
|
|
2
2
|
name = "speedy-utils"
|
|
3
|
-
version = "1.1.12"
|
|
3
|
+
version = "1.1.14"
|
|
4
4
|
description = "Fast and easy-to-use package for data science"
|
|
5
5
|
authors = ["AnhVTH <anhvth.226@gmail.com>"]
|
|
6
6
|
readme = "README.md"
|
|
@@ -58,6 +58,7 @@ json-repair = ">=0.25.0,<0.31.0"
|
|
|
58
58
|
fastprogress = "*"
|
|
59
59
|
freezegun = "^1.5.1"
|
|
60
60
|
packaging = ">=23.2,<25"
|
|
61
|
+
openai = "^1.106.0"
|
|
61
62
|
|
|
62
63
|
[tool.poetry.scripts]
|
|
63
64
|
mpython = "speedy_utils.scripts.mpython:main"
|
|
@@ -1,4 +1,6 @@
|
|
|
1
1
|
from llm_utils.lm.openai_memoize import MOpenAI
|
|
2
|
+
from llm_utils.vector_cache import VectorCache
|
|
3
|
+
|
|
2
4
|
from .chat_format import (
|
|
3
5
|
build_chatml_input,
|
|
4
6
|
display_chat_messages_as_html,
|
|
@@ -24,5 +26,6 @@ __all__ = [
|
|
|
24
26
|
"display_chat_messages_as_html",
|
|
25
27
|
"AsyncLM",
|
|
26
28
|
"AsyncLLMTask",
|
|
27
|
-
"MOpenAI"
|
|
29
|
+
"MOpenAI",
|
|
30
|
+
"VectorCache",
|
|
28
31
|
]
|
|
@@ -5,6 +5,7 @@ from typing import (
|
|
|
5
5
|
Literal,
|
|
6
6
|
Optional,
|
|
7
7
|
Type,
|
|
8
|
+
Union,
|
|
8
9
|
cast,
|
|
9
10
|
)
|
|
10
11
|
|
|
@@ -49,7 +50,7 @@ class AsyncLM(AsyncLMBase):
|
|
|
49
50
|
temperature: float = 0.0,
|
|
50
51
|
max_tokens: int = 2_000,
|
|
51
52
|
host: str = "localhost",
|
|
52
|
-
port: Optional[int | str] = None,
|
|
53
|
+
port: Optional[Union[int, str]] = None,
|
|
53
54
|
base_url: Optional[str] = None,
|
|
54
55
|
api_key: Optional[str] = None,
|
|
55
56
|
cache: bool = True,
|
|
@@ -41,7 +41,7 @@ class AsyncLMBase:
|
|
|
41
41
|
self,
|
|
42
42
|
*,
|
|
43
43
|
host: str = "localhost",
|
|
44
|
-
port: Optional[int | str] = None,
|
|
44
|
+
port: Optional[Union[int, str]] = None,
|
|
45
45
|
base_url: Optional[str] = None,
|
|
46
46
|
api_key: Optional[str] = None,
|
|
47
47
|
cache: bool = True,
|
|
@@ -81,8 +81,8 @@ class AsyncLMBase:
|
|
|
81
81
|
async def __call__( # type: ignore
|
|
82
82
|
self,
|
|
83
83
|
*,
|
|
84
|
-
prompt: str | None = ...,
|
|
85
|
-
messages: RawMsgs | None = ...,
|
|
84
|
+
prompt: Optional[str] = ...,
|
|
85
|
+
messages: Optional[RawMsgs] = ...,
|
|
86
86
|
response_format: type[str] = str,
|
|
87
87
|
return_openai_response: bool = ...,
|
|
88
88
|
**kwargs: Any,
|
|
@@ -92,8 +92,8 @@ class AsyncLMBase:
|
|
|
92
92
|
async def __call__(
|
|
93
93
|
self,
|
|
94
94
|
*,
|
|
95
|
-
prompt: str | None = ...,
|
|
96
|
-
messages: RawMsgs | None = ...,
|
|
95
|
+
prompt: Optional[str] = ...,
|
|
96
|
+
messages: Optional[RawMsgs] = ...,
|
|
97
97
|
response_format: Type[TModel],
|
|
98
98
|
return_openai_response: bool = ...,
|
|
99
99
|
**kwargs: Any,
|
|
@@ -137,7 +137,7 @@ class AsyncLMBase:
|
|
|
137
137
|
@staticmethod
|
|
138
138
|
def _parse_output(
|
|
139
139
|
raw_response: Any, response_format: Union[type[str], Type[BaseModel]]
|
|
140
|
-
) -> str | BaseModel:
|
|
140
|
+
) -> Union[str, BaseModel]:
|
|
141
141
|
if hasattr(raw_response, "model_dump"):
|
|
142
142
|
raw_response = raw_response.model_dump()
|
|
143
143
|
|
|
@@ -0,0 +1,25 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Efficient embedding caching system using vLLM for offline embeddings.
|
|
3
|
+
|
|
4
|
+
This package provides a fast, SQLite-backed caching layer for text embeddings,
|
|
5
|
+
supporting both OpenAI API and local models via vLLM.
|
|
6
|
+
|
|
7
|
+
Classes:
|
|
8
|
+
VectorCache: Main class for embedding computation and caching
|
|
9
|
+
|
|
10
|
+
Example:
|
|
11
|
+
# Using local model
|
|
12
|
+
cache = VectorCache("Qwen/Qwen3-Embedding-0.6B")
|
|
13
|
+
embeddings = cache.embeds(["Hello world", "How are you?"])
|
|
14
|
+
|
|
15
|
+
# Using OpenAI API
|
|
16
|
+
cache = VectorCache("https://api.openai.com/v1")
|
|
17
|
+
embeddings = cache.embeds(["Hello world", "How are you?"])
|
|
18
|
+
"""
|
|
19
|
+
|
|
20
|
+
from .core import VectorCache
|
|
21
|
+
from .utils import get_default_cache_path, validate_model_name, estimate_cache_size
|
|
22
|
+
|
|
23
|
+
__version__ = "0.1.0"
|
|
24
|
+
__author__ = "AnhVTH <anhvth.226@gmail.com>"
|
|
25
|
+
__all__ = ["VectorCache", "get_default_cache_path", "validate_model_name", "estimate_cache_size"]
|
|
@@ -0,0 +1,200 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
"""Command-line interface for embed_cache package."""
|
|
3
|
+
|
|
4
|
+
import argparse
|
|
5
|
+
import json
|
|
6
|
+
import sys
|
|
7
|
+
from pathlib import Path
|
|
8
|
+
from typing import List
|
|
9
|
+
|
|
10
|
+
from llm_utils.vector_cache import VectorCache, estimate_cache_size, validate_model_name
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def main():
|
|
14
|
+
"""Main CLI entry point."""
|
|
15
|
+
parser = argparse.ArgumentParser(
|
|
16
|
+
description="Command-line interface for embed_cache package"
|
|
17
|
+
)
|
|
18
|
+
|
|
19
|
+
subparsers = parser.add_subparsers(dest="command", help="Available commands")
|
|
20
|
+
|
|
21
|
+
# Embed command
|
|
22
|
+
embed_parser = subparsers.add_parser("embed", help="Generate embeddings for texts")
|
|
23
|
+
embed_parser.add_argument("model", help="Model name or API URL")
|
|
24
|
+
embed_parser.add_argument("--texts", nargs="+", help="Texts to embed")
|
|
25
|
+
embed_parser.add_argument("--file", help="File containing texts (one per line)")
|
|
26
|
+
embed_parser.add_argument("--output", help="Output file for embeddings (JSON)")
|
|
27
|
+
embed_parser.add_argument(
|
|
28
|
+
"--cache-db", default="embed_cache.sqlite", help="Cache database path"
|
|
29
|
+
)
|
|
30
|
+
embed_parser.add_argument(
|
|
31
|
+
"--backend",
|
|
32
|
+
choices=["vllm", "transformers", "openai", "auto"],
|
|
33
|
+
default="auto",
|
|
34
|
+
help="Backend to use",
|
|
35
|
+
)
|
|
36
|
+
embed_parser.add_argument(
|
|
37
|
+
"--gpu-memory",
|
|
38
|
+
type=float,
|
|
39
|
+
default=0.5,
|
|
40
|
+
help="GPU memory utilization for vLLM (0.0-1.0)",
|
|
41
|
+
)
|
|
42
|
+
embed_parser.add_argument(
|
|
43
|
+
"--batch-size", type=int, default=32, help="Batch size for transformers"
|
|
44
|
+
)
|
|
45
|
+
embed_parser.add_argument(
|
|
46
|
+
"--verbose", action="store_true", help="Enable verbose output"
|
|
47
|
+
)
|
|
48
|
+
|
|
49
|
+
# Cache stats command
|
|
50
|
+
stats_parser = subparsers.add_parser("stats", help="Show cache statistics")
|
|
51
|
+
stats_parser.add_argument(
|
|
52
|
+
"--cache-db", default="embed_cache.sqlite", help="Cache database path"
|
|
53
|
+
)
|
|
54
|
+
|
|
55
|
+
# Clear cache command
|
|
56
|
+
clear_parser = subparsers.add_parser("clear", help="Clear cache")
|
|
57
|
+
clear_parser.add_argument(
|
|
58
|
+
"--cache-db", default="embed_cache.sqlite", help="Cache database path"
|
|
59
|
+
)
|
|
60
|
+
clear_parser.add_argument(
|
|
61
|
+
"--confirm", action="store_true", help="Skip confirmation prompt"
|
|
62
|
+
)
|
|
63
|
+
|
|
64
|
+
# Validate model command
|
|
65
|
+
validate_parser = subparsers.add_parser("validate", help="Validate model name")
|
|
66
|
+
validate_parser.add_argument("model", help="Model name to validate")
|
|
67
|
+
|
|
68
|
+
# Estimate command
|
|
69
|
+
estimate_parser = subparsers.add_parser("estimate", help="Estimate cache size")
|
|
70
|
+
estimate_parser.add_argument("num_texts", type=int, help="Number of texts")
|
|
71
|
+
estimate_parser.add_argument(
|
|
72
|
+
"--embed-dim", type=int, default=1024, help="Embedding dimension"
|
|
73
|
+
)
|
|
74
|
+
|
|
75
|
+
args = parser.parse_args()
|
|
76
|
+
|
|
77
|
+
if not args.command:
|
|
78
|
+
parser.print_help()
|
|
79
|
+
return
|
|
80
|
+
|
|
81
|
+
try:
|
|
82
|
+
if args.command == "embed":
|
|
83
|
+
handle_embed(args)
|
|
84
|
+
elif args.command == "stats":
|
|
85
|
+
handle_stats(args)
|
|
86
|
+
elif args.command == "clear":
|
|
87
|
+
handle_clear(args)
|
|
88
|
+
elif args.command == "validate":
|
|
89
|
+
handle_validate(args)
|
|
90
|
+
elif args.command == "estimate":
|
|
91
|
+
handle_estimate(args)
|
|
92
|
+
except Exception as e:
|
|
93
|
+
print(f"Error: {e}", file=sys.stderr)
|
|
94
|
+
sys.exit(1)
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
def handle_embed(args):
|
|
98
|
+
"""Handle embed command."""
|
|
99
|
+
# Get texts
|
|
100
|
+
texts = []
|
|
101
|
+
if args.texts:
|
|
102
|
+
texts.extend(args.texts)
|
|
103
|
+
|
|
104
|
+
if args.file:
|
|
105
|
+
file_path = Path(args.file)
|
|
106
|
+
if not file_path.exists():
|
|
107
|
+
raise FileNotFoundError(f"File not found: {args.file}")
|
|
108
|
+
|
|
109
|
+
with open(file_path, "r", encoding="utf-8") as f:
|
|
110
|
+
texts.extend([line.strip() for line in f if line.strip()])
|
|
111
|
+
|
|
112
|
+
if not texts:
|
|
113
|
+
raise ValueError("No texts provided. Use --texts or --file")
|
|
114
|
+
|
|
115
|
+
print(f"Embedding {len(texts)} texts using model: {args.model}")
|
|
116
|
+
|
|
117
|
+
# Initialize cache and get embeddings
|
|
118
|
+
cache = VectorCache(args.model, db_path=args.cache_db)
|
|
119
|
+
embeddings = cache.embeds(texts)
|
|
120
|
+
|
|
121
|
+
print(f"Generated embeddings with shape: {embeddings.shape}")
|
|
122
|
+
|
|
123
|
+
# Output results
|
|
124
|
+
if args.output:
|
|
125
|
+
output_data = {
|
|
126
|
+
"texts": texts,
|
|
127
|
+
"embeddings": embeddings.tolist(),
|
|
128
|
+
"shape": list(embeddings.shape),
|
|
129
|
+
"model": args.model,
|
|
130
|
+
}
|
|
131
|
+
|
|
132
|
+
with open(args.output, "w", encoding="utf-8") as f:
|
|
133
|
+
json.dump(output_data, f, indent=2)
|
|
134
|
+
|
|
135
|
+
print(f"Results saved to: {args.output}")
|
|
136
|
+
else:
|
|
137
|
+
print(f"Embeddings shape: {embeddings.shape}")
|
|
138
|
+
print(f"Sample embedding (first 5 dims): {embeddings[0][:5].tolist()}")
|
|
139
|
+
|
|
140
|
+
|
|
141
|
+
def handle_stats(args):
|
|
142
|
+
"""Handle stats command."""
|
|
143
|
+
cache_path = Path(args.cache_db)
|
|
144
|
+
if not cache_path.exists():
|
|
145
|
+
print(f"Cache database not found: {args.cache_db}")
|
|
146
|
+
return
|
|
147
|
+
|
|
148
|
+
cache = VectorCache("dummy", db_path=args.cache_db)
|
|
149
|
+
stats = cache.get_cache_stats()
|
|
150
|
+
|
|
151
|
+
print("Cache Statistics:")
|
|
152
|
+
print(f" Database: {args.cache_db}")
|
|
153
|
+
print(f" Total cached embeddings: {stats['total_cached']}")
|
|
154
|
+
print(f" Database size: {cache_path.stat().st_size / (1024 * 1024):.2f} MB")
|
|
155
|
+
|
|
156
|
+
|
|
157
|
+
def handle_clear(args):
|
|
158
|
+
"""Handle clear command."""
|
|
159
|
+
cache_path = Path(args.cache_db)
|
|
160
|
+
if not cache_path.exists():
|
|
161
|
+
print(f"Cache database not found: {args.cache_db}")
|
|
162
|
+
return
|
|
163
|
+
|
|
164
|
+
if not args.confirm:
|
|
165
|
+
response = input(
|
|
166
|
+
f"Are you sure you want to clear cache at {args.cache_db}? [y/N]: "
|
|
167
|
+
)
|
|
168
|
+
if response.lower() != "y":
|
|
169
|
+
print("Cancelled.")
|
|
170
|
+
return
|
|
171
|
+
|
|
172
|
+
cache = VectorCache("dummy", db_path=args.cache_db)
|
|
173
|
+
stats_before = cache.get_cache_stats()
|
|
174
|
+
cache.clear_cache()
|
|
175
|
+
|
|
176
|
+
print(
|
|
177
|
+
f"Cleared {stats_before['total_cached']} cached embeddings from {args.cache_db}"
|
|
178
|
+
)
|
|
179
|
+
|
|
180
|
+
|
|
181
|
+
def handle_validate(args):
|
|
182
|
+
"""Handle validate command."""
|
|
183
|
+
is_valid = validate_model_name(args.model)
|
|
184
|
+
if is_valid:
|
|
185
|
+
print(f"✓ Valid model: {args.model}")
|
|
186
|
+
else:
|
|
187
|
+
print(f"✗ Invalid model: {args.model}")
|
|
188
|
+
sys.exit(1)
|
|
189
|
+
|
|
190
|
+
|
|
191
|
+
def handle_estimate(args):
|
|
192
|
+
"""Handle estimate command."""
|
|
193
|
+
size_estimate = estimate_cache_size(args.num_texts, args.embed_dim)
|
|
194
|
+
print(
|
|
195
|
+
f"Estimated cache size for {args.num_texts} texts ({args.embed_dim}D embeddings): {size_estimate}"
|
|
196
|
+
)
|
|
197
|
+
|
|
198
|
+
|
|
199
|
+
if __name__ == "__main__":
|
|
200
|
+
main()
|
|
@@ -0,0 +1,557 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import hashlib
|
|
4
|
+
import os
|
|
5
|
+
import sqlite3
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
from time import time
|
|
8
|
+
from typing import Any, Dict, Literal, Optional, cast
|
|
9
|
+
|
|
10
|
+
import numpy as np
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class VectorCache:
|
|
14
|
+
"""
|
|
15
|
+
A caching layer for text embeddings with support for multiple backends.
|
|
16
|
+
|
|
17
|
+
Examples:
|
|
18
|
+
# OpenAI API
|
|
19
|
+
from llm_utils import VectorCache
|
|
20
|
+
cache = VectorCache("https://api.openai.com/v1", api_key="your-key")
|
|
21
|
+
embeddings = cache.embeds(["Hello world", "How are you?"])
|
|
22
|
+
|
|
23
|
+
# Custom OpenAI-compatible server (auto-detects model)
|
|
24
|
+
cache = VectorCache("http://localhost:8000/v1", api_key="abc")
|
|
25
|
+
|
|
26
|
+
# Transformers (Sentence Transformers)
|
|
27
|
+
cache = VectorCache("sentence-transformers/all-MiniLM-L6-v2")
|
|
28
|
+
|
|
29
|
+
# vLLM (local model)
|
|
30
|
+
cache = VectorCache("/path/to/model")
|
|
31
|
+
|
|
32
|
+
# Explicit backend specification
|
|
33
|
+
cache = VectorCache("model-name", backend="transformers")
|
|
34
|
+
|
|
35
|
+
# Lazy loading (default: True) - load model only when needed
|
|
36
|
+
cache = VectorCache("model-name", lazy=True)
|
|
37
|
+
|
|
38
|
+
# Eager loading - load model immediately
|
|
39
|
+
cache = VectorCache("model-name", lazy=False)
|
|
40
|
+
"""
|
|
41
|
+
def __init__(
|
|
42
|
+
self,
|
|
43
|
+
url_or_model: str,
|
|
44
|
+
backend: Optional[Literal["vllm", "transformers", "openai"]] = None,
|
|
45
|
+
embed_size: Optional[int] = None,
|
|
46
|
+
db_path: Optional[str] = None,
|
|
47
|
+
# OpenAI API parameters
|
|
48
|
+
api_key: Optional[str] = "abc",
|
|
49
|
+
model_name: Optional[str] = None,
|
|
50
|
+
# vLLM parameters
|
|
51
|
+
vllm_gpu_memory_utilization: float = 0.5,
|
|
52
|
+
vllm_tensor_parallel_size: int = 1,
|
|
53
|
+
vllm_dtype: str = "auto",
|
|
54
|
+
vllm_trust_remote_code: bool = False,
|
|
55
|
+
vllm_max_model_len: Optional[int] = None,
|
|
56
|
+
# Transformers parameters
|
|
57
|
+
transformers_device: str = "auto",
|
|
58
|
+
transformers_batch_size: int = 32,
|
|
59
|
+
transformers_normalize_embeddings: bool = True,
|
|
60
|
+
transformers_trust_remote_code: bool = False,
|
|
61
|
+
# SQLite parameters
|
|
62
|
+
sqlite_chunk_size: int = 999,
|
|
63
|
+
sqlite_cache_size: int = 10000,
|
|
64
|
+
sqlite_mmap_size: int = 268435456,
|
|
65
|
+
# Other parameters
|
|
66
|
+
verbose: bool = True,
|
|
67
|
+
lazy: bool = True,
|
|
68
|
+
) -> None:
|
|
69
|
+
self.url_or_model = url_or_model
|
|
70
|
+
self.embed_size = embed_size
|
|
71
|
+
self.verbose = verbose
|
|
72
|
+
self.lazy = lazy
|
|
73
|
+
|
|
74
|
+
self.backend = self._determine_backend(backend)
|
|
75
|
+
if self.verbose and backend is None:
|
|
76
|
+
print(f"Auto-detected backend: {self.backend}")
|
|
77
|
+
|
|
78
|
+
# Store all configuration parameters
|
|
79
|
+
self.config = {
|
|
80
|
+
# OpenAI
|
|
81
|
+
"api_key": api_key or os.getenv("OPENAI_API_KEY"),
|
|
82
|
+
"model_name": self._try_infer_model_name(model_name),
|
|
83
|
+
# vLLM
|
|
84
|
+
"vllm_gpu_memory_utilization": vllm_gpu_memory_utilization,
|
|
85
|
+
"vllm_tensor_parallel_size": vllm_tensor_parallel_size,
|
|
86
|
+
"vllm_dtype": vllm_dtype,
|
|
87
|
+
"vllm_trust_remote_code": vllm_trust_remote_code,
|
|
88
|
+
"vllm_max_model_len": vllm_max_model_len,
|
|
89
|
+
# Transformers
|
|
90
|
+
"transformers_device": transformers_device,
|
|
91
|
+
"transformers_batch_size": transformers_batch_size,
|
|
92
|
+
"transformers_normalize_embeddings": transformers_normalize_embeddings,
|
|
93
|
+
"transformers_trust_remote_code": transformers_trust_remote_code,
|
|
94
|
+
# SQLite
|
|
95
|
+
"sqlite_chunk_size": sqlite_chunk_size,
|
|
96
|
+
"sqlite_cache_size": sqlite_cache_size,
|
|
97
|
+
"sqlite_mmap_size": sqlite_mmap_size,
|
|
98
|
+
}
|
|
99
|
+
|
|
100
|
+
# Auto-detect model_name for OpenAI if using custom URL and default model
|
|
101
|
+
if (self.backend == "openai" and
|
|
102
|
+
model_name == "text-embedding-3-small" and
|
|
103
|
+
self.url_or_model != "https://api.openai.com/v1"):
|
|
104
|
+
if self.verbose:
|
|
105
|
+
print(f"Attempting to auto-detect model from {self.url_or_model}...")
|
|
106
|
+
try:
|
|
107
|
+
import openai
|
|
108
|
+
client = openai.OpenAI(
|
|
109
|
+
base_url=self.url_or_model,
|
|
110
|
+
api_key=self.config["api_key"]
|
|
111
|
+
)
|
|
112
|
+
models = client.models.list()
|
|
113
|
+
if models.data:
|
|
114
|
+
detected_model = models.data[0].id
|
|
115
|
+
self.config["model_name"] = detected_model
|
|
116
|
+
model_name = detected_model # Update for db_path computation
|
|
117
|
+
if self.verbose:
|
|
118
|
+
print(f"Auto-detected model: {detected_model}")
|
|
119
|
+
else:
|
|
120
|
+
if self.verbose:
|
|
121
|
+
print("No models found, using default model")
|
|
122
|
+
except Exception as e:
|
|
123
|
+
if self.verbose:
|
|
124
|
+
print(f"Model auto-detection failed: {e}, using default model")
|
|
125
|
+
# Fallback to default if auto-detection fails
|
|
126
|
+
pass
|
|
127
|
+
|
|
128
|
+
# Set default db_path if not provided
|
|
129
|
+
if db_path is None:
|
|
130
|
+
if self.backend == "openai":
|
|
131
|
+
model_id = self.config["model_name"] or "openai-default"
|
|
132
|
+
else:
|
|
133
|
+
model_id = self.url_or_model
|
|
134
|
+
safe_name = hashlib.sha1(model_id.encode("utf-8")).hexdigest()[:16]
|
|
135
|
+
self.db_path = Path.home() / ".cache" / "embed" / f"{self.backend}_{safe_name}.sqlite"
|
|
136
|
+
else:
|
|
137
|
+
self.db_path = Path(db_path)
|
|
138
|
+
|
|
139
|
+
# Ensure the directory exists
|
|
140
|
+
self.db_path.parent.mkdir(parents=True, exist_ok=True)
|
|
141
|
+
|
|
142
|
+
self.conn = sqlite3.connect(self.db_path)
|
|
143
|
+
self._optimize_connection()
|
|
144
|
+
self._ensure_schema()
|
|
145
|
+
self._model = None # Lazy loading
|
|
146
|
+
self._client = None # For OpenAI client
|
|
147
|
+
|
|
148
|
+
# Load model/client if not lazy
|
|
149
|
+
if not self.lazy:
|
|
150
|
+
if self.backend == "openai":
|
|
151
|
+
self._load_openai_client()
|
|
152
|
+
elif self.backend in ["vllm", "transformers"]:
|
|
153
|
+
self._load_model()
|
|
154
|
+
|
|
155
|
+
def _determine_backend(self, backend: Optional[Literal["vllm", "transformers", "openai"]]) -> str:
|
|
156
|
+
"""Determine the appropriate backend based on url_or_model and user preference."""
|
|
157
|
+
if backend is not None:
|
|
158
|
+
valid_backends = ["vllm", "transformers", "openai"]
|
|
159
|
+
if backend not in valid_backends:
|
|
160
|
+
raise ValueError(f"Invalid backend '{backend}'. Must be one of: {valid_backends}")
|
|
161
|
+
return backend
|
|
162
|
+
|
|
163
|
+
if self.url_or_model.startswith("http"):
|
|
164
|
+
return "openai"
|
|
165
|
+
|
|
166
|
+
# Default to vllm for local models
|
|
167
|
+
return "vllm"
|
|
168
|
+
def _try_infer_model_name(self, model_name: Optional[str]) -> Optional[str]:
|
|
169
|
+
"""Infer model name for OpenAI backend if not explicitly provided."""
|
|
170
|
+
# if self.backend != "openai":
|
|
171
|
+
# return model_name
|
|
172
|
+
if model_name:
|
|
173
|
+
return model_name
|
|
174
|
+
if 'https://' in self.url_or_model:
|
|
175
|
+
model_name = "text-embedding-3-small"
|
|
176
|
+
if 'http://localhost' in self.url_or_model:
|
|
177
|
+
from openai import OpenAI
|
|
178
|
+
client = OpenAI(base_url=self.url_or_model, api_key='abc')
|
|
179
|
+
model_name = client.models.list().data[0].id
|
|
180
|
+
|
|
181
|
+
# Default model name
|
|
182
|
+
print('Infer model name:', model_name)
|
|
183
|
+
return model_name
|
|
184
|
+
def _optimize_connection(self) -> None:
|
|
185
|
+
"""Optimize SQLite connection for bulk operations."""
|
|
186
|
+
# Performance optimizations for bulk operations
|
|
187
|
+
self.conn.execute(
|
|
188
|
+
"PRAGMA journal_mode=WAL"
|
|
189
|
+
) # Write-Ahead Logging for better concurrency
|
|
190
|
+
self.conn.execute("PRAGMA synchronous=NORMAL") # Faster writes, still safe
|
|
191
|
+
self.conn.execute(f"PRAGMA cache_size={self.config['sqlite_cache_size']}") # Configurable cache
|
|
192
|
+
self.conn.execute("PRAGMA temp_store=MEMORY") # Use memory for temp storage
|
|
193
|
+
self.conn.execute(f"PRAGMA mmap_size={self.config['sqlite_mmap_size']}") # Configurable memory mapping
|
|
194
|
+
|
|
195
|
+
def _ensure_schema(self) -> None:
|
|
196
|
+
self.conn.execute("""
|
|
197
|
+
CREATE TABLE IF NOT EXISTS cache (
|
|
198
|
+
hash TEXT PRIMARY KEY,
|
|
199
|
+
text TEXT,
|
|
200
|
+
embedding BLOB
|
|
201
|
+
)
|
|
202
|
+
""")
|
|
203
|
+
# Add index for faster lookups if it doesn't exist
|
|
204
|
+
self.conn.execute("""
|
|
205
|
+
CREATE INDEX IF NOT EXISTS idx_cache_hash ON cache(hash)
|
|
206
|
+
""")
|
|
207
|
+
self.conn.commit()
|
|
208
|
+
|
|
209
|
+
def _load_openai_client(self) -> None:
|
|
210
|
+
"""Load OpenAI client."""
|
|
211
|
+
import openai
|
|
212
|
+
self._client = openai.OpenAI(
|
|
213
|
+
base_url=self.url_or_model,
|
|
214
|
+
api_key=self.config["api_key"]
|
|
215
|
+
)
|
|
216
|
+
|
|
217
|
+
def _load_model(self) -> None:
|
|
218
|
+
"""Load the model for vLLM or Transformers."""
|
|
219
|
+
if self.backend == "vllm":
|
|
220
|
+
from vllm import LLM
|
|
221
|
+
|
|
222
|
+
gpu_memory_utilization = cast(float, self.config["vllm_gpu_memory_utilization"])
|
|
223
|
+
tensor_parallel_size = cast(int, self.config["vllm_tensor_parallel_size"])
|
|
224
|
+
dtype = cast(str, self.config["vllm_dtype"])
|
|
225
|
+
trust_remote_code = cast(bool, self.config["vllm_trust_remote_code"])
|
|
226
|
+
max_model_len = cast(Optional[int], self.config["vllm_max_model_len"])
|
|
227
|
+
|
|
228
|
+
vllm_kwargs = {
|
|
229
|
+
"model": self.url_or_model,
|
|
230
|
+
"task": "embed",
|
|
231
|
+
"gpu_memory_utilization": gpu_memory_utilization,
|
|
232
|
+
"tensor_parallel_size": tensor_parallel_size,
|
|
233
|
+
"dtype": dtype,
|
|
234
|
+
"trust_remote_code": trust_remote_code,
|
|
235
|
+
}
|
|
236
|
+
|
|
237
|
+
if max_model_len is not None:
|
|
238
|
+
vllm_kwargs["max_model_len"] = max_model_len
|
|
239
|
+
|
|
240
|
+
try:
|
|
241
|
+
self._model = LLM(**vllm_kwargs)
|
|
242
|
+
except (ValueError, AssertionError, RuntimeError) as e:
|
|
243
|
+
error_msg = str(e).lower()
|
|
244
|
+
if ("kv cache" in error_msg and "gpu_memory_utilization" in error_msg) or \
|
|
245
|
+
("memory" in error_msg and ("gpu" in error_msg or "insufficient" in error_msg)) or \
|
|
246
|
+
("free memory" in error_msg and "initial" in error_msg) or \
|
|
247
|
+
("engine core initialization failed" in error_msg):
|
|
248
|
+
raise ValueError(
|
|
249
|
+
f"Insufficient GPU memory for vLLM model initialization. "
|
|
250
|
+
f"Current vllm_gpu_memory_utilization ({gpu_memory_utilization}) may be too low. "
|
|
251
|
+
f"Try one of the following:\n"
|
|
252
|
+
f"1. Increase vllm_gpu_memory_utilization (e.g., 0.5, 0.8, or 0.9)\n"
|
|
253
|
+
f"2. Decrease vllm_max_model_len (e.g., 4096, 8192)\n"
|
|
254
|
+
f"3. Use a smaller model\n"
|
|
255
|
+
f"4. Ensure no other processes are using GPU memory during initialization\n"
|
|
256
|
+
f"Original error: {e}"
|
|
257
|
+
) from e
|
|
258
|
+
else:
|
|
259
|
+
raise
|
|
260
|
+
elif self.backend == "transformers":
|
|
261
|
+
from transformers import AutoTokenizer, AutoModel
|
|
262
|
+
import torch
|
|
263
|
+
|
|
264
|
+
device = self.config["transformers_device"]
|
|
265
|
+
# Handle "auto" device selection - default to CPU for transformers to avoid memory conflicts
|
|
266
|
+
if device == "auto":
|
|
267
|
+
device = "cpu" # Default to CPU to avoid GPU memory conflicts with vLLM
|
|
268
|
+
|
|
269
|
+
tokenizer = AutoTokenizer.from_pretrained(self.url_or_model, padding_side='left', trust_remote_code=self.config["transformers_trust_remote_code"])
|
|
270
|
+
model = AutoModel.from_pretrained(self.url_or_model, trust_remote_code=self.config["transformers_trust_remote_code"])
|
|
271
|
+
|
|
272
|
+
# Move model to device
|
|
273
|
+
model.to(device)
|
|
274
|
+
model.eval()
|
|
275
|
+
|
|
276
|
+
self._model = {"tokenizer": tokenizer, "model": model, "device": device}
|
|
277
|
+
|
|
278
|
+
def _get_embeddings(self, texts: list[str]) -> list[list[float]]:
|
|
279
|
+
"""Get embeddings using the configured backend."""
|
|
280
|
+
if self.backend == "openai":
|
|
281
|
+
return self._get_openai_embeddings(texts)
|
|
282
|
+
elif self.backend == "vllm":
|
|
283
|
+
return self._get_vllm_embeddings(texts)
|
|
284
|
+
elif self.backend == "transformers":
|
|
285
|
+
return self._get_transformers_embeddings(texts)
|
|
286
|
+
else:
|
|
287
|
+
raise ValueError(f"Unsupported backend: {self.backend}")
|
|
288
|
+
|
|
289
|
+
def _get_openai_embeddings(self, texts: list[str]) -> list[list[float]]:
|
|
290
|
+
"""Get embeddings using OpenAI API."""
|
|
291
|
+
# Assert valid model_name for OpenAI backend
|
|
292
|
+
model_name = self.config["model_name"]
|
|
293
|
+
assert model_name is not None and model_name.strip(), f"Invalid model_name for OpenAI backend: {model_name}. Model name must be provided and non-empty."
|
|
294
|
+
|
|
295
|
+
if self._client is None:
|
|
296
|
+
self._load_openai_client()
|
|
297
|
+
|
|
298
|
+
response = self._client.embeddings.create( # type: ignore
|
|
299
|
+
model=model_name,
|
|
300
|
+
input=texts
|
|
301
|
+
)
|
|
302
|
+
embeddings = [item.embedding for item in response.data]
|
|
303
|
+
return embeddings
|
|
304
|
+
|
|
305
|
+
def _get_vllm_embeddings(self, texts: list[str]) -> list[list[float]]:
|
|
306
|
+
"""Get embeddings using vLLM."""
|
|
307
|
+
if self._model is None:
|
|
308
|
+
self._load_model()
|
|
309
|
+
|
|
310
|
+
outputs = self._model.embed(texts) # type: ignore
|
|
311
|
+
embeddings = [o.outputs.embedding for o in outputs]
|
|
312
|
+
return embeddings
|
|
313
|
+
|
|
314
|
+
def _get_transformers_embeddings(self, texts: list[str]) -> list[list[float]]:
|
|
315
|
+
"""Get embeddings using transformers directly."""
|
|
316
|
+
if self._model is None:
|
|
317
|
+
self._load_model()
|
|
318
|
+
|
|
319
|
+
if not isinstance(self._model, dict):
|
|
320
|
+
raise ValueError("Model not loaded properly for transformers backend")
|
|
321
|
+
|
|
322
|
+
tokenizer = self._model["tokenizer"]
|
|
323
|
+
model = self._model["model"]
|
|
324
|
+
device = self._model["device"]
|
|
325
|
+
|
|
326
|
+
normalize_embeddings = cast(bool, self.config["transformers_normalize_embeddings"])
|
|
327
|
+
|
|
328
|
+
# For now, use a default max_length
|
|
329
|
+
max_length = 8192
|
|
330
|
+
|
|
331
|
+
# Tokenize
|
|
332
|
+
batch_dict = tokenizer(
|
|
333
|
+
texts,
|
|
334
|
+
padding=True,
|
|
335
|
+
truncation=True,
|
|
336
|
+
max_length=max_length,
|
|
337
|
+
return_tensors="pt",
|
|
338
|
+
)
|
|
339
|
+
|
|
340
|
+
# Move to device
|
|
341
|
+
batch_dict = {k: v.to(device) for k, v in batch_dict.items()}
|
|
342
|
+
|
|
343
|
+
# Run model
|
|
344
|
+
import torch
|
|
345
|
+
with torch.no_grad():
|
|
346
|
+
outputs = model(**batch_dict)
|
|
347
|
+
|
|
348
|
+
# Apply last token pooling
|
|
349
|
+
embeddings = self._last_token_pool(outputs.last_hidden_state, batch_dict['attention_mask'])
|
|
350
|
+
|
|
351
|
+
# Normalize if needed
|
|
352
|
+
if normalize_embeddings:
|
|
353
|
+
import torch.nn.functional as F
|
|
354
|
+
embeddings = F.normalize(embeddings, p=2, dim=1)
|
|
355
|
+
|
|
356
|
+
return embeddings.cpu().numpy().tolist()
|
|
357
|
+
|
|
358
|
+
def _last_token_pool(self, last_hidden_states, attention_mask):
|
|
359
|
+
"""Apply last token pooling to get embeddings."""
|
|
360
|
+
import torch
|
|
361
|
+
left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
|
|
362
|
+
if left_padding:
|
|
363
|
+
return last_hidden_states[:, -1]
|
|
364
|
+
else:
|
|
365
|
+
sequence_lengths = attention_mask.sum(dim=1) - 1
|
|
366
|
+
batch_size = last_hidden_states.shape[0]
|
|
367
|
+
return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths]
|
|
368
|
+
|
|
369
|
+
def _hash_text(self, text: str) -> str:
|
|
370
|
+
return hashlib.sha1(text.encode("utf-8")).hexdigest()
|
|
371
|
+
|
|
372
|
+
def embeds(self, texts: list[str], cache: bool = True) -> np.ndarray:
    """
    Return embeddings for all texts, one row per input, in the input order.

    If cache=True, rows already stored in the ``cache`` table are reused and
    only the missing texts are embedded and stored.
    If cache=False, every text is re-embedded and its cached row is replaced;
    the existing rows are not read at all.

    A text that occurs several times in ``texts`` is embedded once and the
    resulting vector is used for all of its positions.

    Args:
        texts: Texts to embed.
        cache: Whether previously cached embeddings may be reused.

    Returns:
        The stacked float32 embeddings, or an empty ``(0, 0)`` float32 array
        when ``texts`` is empty.
    """
    if not texts:
        return np.empty((0, 0), dtype=np.float32)
    started = time()
    hashes = [self._hash_text(text) for text in texts]

    hit_map: dict[str, np.ndarray] = {}

    if cache:
        # Look each distinct hash up once. The IN (...) list is chunked so a
        # statement never binds more than config["sqlite_chunk_size"] values.
        unique_hashes = list(dict.fromkeys(hashes))
        chunk_size = self.config["sqlite_chunk_size"]
        for start in range(0, len(unique_hashes), chunk_size):
            chunk = unique_hashes[start : start + chunk_size]
            placeholders = ",".join("?" * len(chunk))
            rows = self.conn.execute(
                f"SELECT hash, embedding FROM cache WHERE hash IN ({placeholders})",
                chunk,
            ).fetchall()
            for h, blob in rows:
                hit_map[h] = np.frombuffer(blob, dtype=np.float32)

    # hash -> text for everything that still needs a vector. Keyed by hash so
    # repeated texts are embedded and inserted once; with cache=False hit_map
    # is empty, which makes every distinct text "missing".
    missing: dict[str, str] = {}
    for text, h in zip(texts, hashes):
        if h not in hit_map:
            missing.setdefault(h, text)

    if missing:
        if self.verbose:
            print(f"Computing embeddings for {len(missing)} missing texts...")
        vectors = self._get_embeddings(list(missing.values()))

        # Prepare batch data for bulk insert
        bulk_insert_data: list[tuple[str, str, bytes]] = []
        for (h, text), vec in zip(missing.items(), vectors):
            arr = np.asarray(vec, dtype=np.float32)
            bulk_insert_data.append((h, text, arr.tobytes()))
            hit_map[h] = arr

        self._bulk_insert(bulk_insert_data)

    # Return embeddings in the original order
    elapsed = time() - started
    if self.verbose:
        print(f"Retrieved {len(texts)} embeddings in {elapsed:.2f} seconds")
    return np.vstack([hit_map[h] for h in hashes])
|
|
437
|
+
|
|
438
|
+
def __call__(self, texts: list[str], cache: bool = True) -> np.ndarray:
|
|
439
|
+
return self.embeds(texts, cache)
|
|
440
|
+
|
|
441
|
+
def _bulk_insert(self, data: list[tuple[str, str, bytes]]) -> None:
|
|
442
|
+
"""Perform bulk insert of embedding data."""
|
|
443
|
+
if not data:
|
|
444
|
+
return
|
|
445
|
+
|
|
446
|
+
self.conn.executemany(
|
|
447
|
+
"INSERT OR REPLACE INTO cache (hash, text, embedding) VALUES (?, ?, ?)",
|
|
448
|
+
data,
|
|
449
|
+
)
|
|
450
|
+
self.conn.commit()
|
|
451
|
+
|
|
452
|
+
def precompute_embeddings(self, texts: list[str]) -> None:
    """
    Embed and store every text in ``texts`` that is not cached yet.

    Duplicate texts are dropped (first occurrence kept). The cache table is
    asked only which hashes exist, not for the stored vectors, and nothing
    is returned to the caller.
    """
    if not texts:
        return

    # dict.fromkeys keeps one entry per text, in first-seen order
    unique_texts = list(dict.fromkeys(texts))
    if self.verbose:
        print(f"Precomputing embeddings for {len(unique_texts)} unique texts...")

    hashes = [self._hash_text(text) for text in unique_texts]

    # Collect the hashes that already have a row, one chunked IN query at a time
    step = self.config["sqlite_chunk_size"]
    cached_hashes = set()
    for start in range(0, len(hashes), step):
        batch = hashes[start : start + step]
        placeholders = ",".join("?" * len(batch))
        found = self.conn.execute(
            f"SELECT hash FROM cache WHERE hash IN ({placeholders})",
            batch,
        ).fetchall()
        for row in found:
            cached_hashes.add(row[0])

    pending = [
        (text, digest)
        for text, digest in zip(unique_texts, hashes)
        if digest not in cached_hashes
    ]

    if not pending:
        if self.verbose:
            print("All texts already cached!")
        return

    if self.verbose:
        print(f"Computing {len(pending)} missing embeddings...")
    vectors = self._get_embeddings([text for text, _ in pending])

    # Rows are (hash, text, raw float32 bytes), the layout _bulk_insert expects
    rows_to_store: list[tuple[str, str, bytes]] = [
        (digest, text, np.asarray(vector, dtype=np.float32).tobytes())
        for (text, digest), vector in zip(pending, vectors)
    ]

    self._bulk_insert(rows_to_store)
    if self.verbose:
        print(f"Successfully cached {len(pending)} new embeddings!")
|
|
504
|
+
|
|
505
|
+
def get_cache_stats(self) -> dict[str, int]:
    """Return ``{"total_cached": n}`` where ``n`` is the row count of the cache table."""
    row = self.conn.execute("SELECT COUNT(*) FROM cache").fetchone()
    return {"total_cached": row[0]}
|
|
510
|
+
|
|
511
|
+
def clear_cache(self) -> None:
    """Delete every row of the ``cache`` table and commit the deletion."""
    connection = self.conn
    connection.execute("DELETE FROM cache")
    connection.commit()
|
|
515
|
+
|
|
516
|
+
def get_config(self) -> Dict[str, Any]:
    """Return the instance settings merged with ``self.config`` as a new dict.

    Entries of ``self.config`` are applied last, so a key it shares with one
    of the fixed fields below replaces that field's value.
    """
    snapshot: Dict[str, Any] = {}
    snapshot["url_or_model"] = self.url_or_model
    snapshot["backend"] = self.backend
    snapshot["embed_size"] = self.embed_size
    snapshot["db_path"] = str(self.db_path)
    snapshot["verbose"] = self.verbose
    snapshot["lazy"] = self.lazy
    snapshot.update(self.config)
    return snapshot
|
|
527
|
+
|
|
528
|
+
def update_config(self, **kwargs) -> None:
    """Update configuration parameters.

    A key that exists in ``self.config`` replaces that entry; ``verbose`` and
    ``lazy`` are set as attributes. When any of the changed keys is listed as
    a parameter of the current backend, the loaded model (and, for the
    ``openai`` backend, the client) is dropped so it is rebuilt on next use.

    Raises:
        ValueError: If a key is neither in ``self.config`` nor one of
            ``verbose`` / ``lazy``. All keys are checked before anything is
            changed, so a rejected call leaves the object untouched.
    """
    attribute_keys = ("verbose", "lazy")

    # Validate first: applying keys one by one would leave earlier keys
    # changed (and the model reset skipped) when a later key is unknown.
    for key in kwargs:
        if key not in self.config and key not in attribute_keys:
            raise ValueError(f"Unknown configuration parameter: {key}")

    for key, value in kwargs.items():
        if key in self.config:
            self.config[key] = value
        else:
            setattr(self, key, value)

    # Reset model if backend-specific parameters changed
    backend_params = {
        "vllm": ["vllm_gpu_memory_utilization", "vllm_tensor_parallel_size", "vllm_dtype",
                 "vllm_trust_remote_code", "vllm_max_model_len"],
        "transformers": ["transformers_device", "transformers_batch_size",
                         "transformers_normalize_embeddings", "transformers_trust_remote_code"],
        "openai": ["api_key", "model_name"],
    }

    if any(param in kwargs for param in backend_params.get(self.backend, [])):
        self._model = None  # Force reload on next use
        if self.backend == "openai":
            self._client = None
|
|
553
|
+
|
|
554
|
+
def __del__(self) -> None:
|
|
555
|
+
"""Clean up database connection."""
|
|
556
|
+
if hasattr(self, "conn"):
|
|
557
|
+
self.conn.close()
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
"""Type definitions for the embed_cache package."""
|
|
2
|
+
|
|
3
|
+
from typing import List, Dict, Any, Union, Optional, Tuple
|
|
4
|
+
import numpy as np
|
|
5
|
+
from numpy.typing import NDArray
|
|
6
|
+
|
|
7
|
+
# Type aliases
TextList = List[str]  # list of strings
EmbeddingArray = NDArray[np.float32]  # NumPy array with float32 elements
EmbeddingList = List[List[float]]  # nested Python lists of floats
CacheStats = Dict[str, int]  # string keys mapped to integer values
ModelIdentifier = str  # Either URL or model name/path

# For backwards compatibility
# Accepts either representation: the NumPy array or the nested-list form.
Embeddings = Union[EmbeddingArray, EmbeddingList]
|
|
@@ -0,0 +1,42 @@
|
|
|
1
|
+
"""Utility functions for the embed_cache package."""
|
|
2
|
+
|
|
3
|
+
import os
|
|
4
|
+
from typing import Optional
|
|
5
|
+
|
|
6
|
+
def get_default_cache_path() -> str:
    """Return the path of ``embed_cache.sqlite`` inside the cache directory.

    The directory is taken from the ``EMBED_CACHE_DIR`` environment variable
    and falls back to the current directory (``"."``) when it is not set.
    """
    directory = os.environ.get("EMBED_CACHE_DIR", ".")
    return os.path.join(directory, "embed_cache.sqlite")
|
|
10
|
+
|
|
11
|
+
def validate_model_name(model_name: str) -> bool:
    """Report whether ``model_name`` is accepted as a model reference.

    A name is accepted when it starts with ``"http"`` (treated as a URL),
    starts with one of the listed organisation prefixes, or is a path that
    exists on disk.
    """
    # Anything beginning with "http" is treated as a URL
    if model_name.startswith("http"):
        return True

    # str.startswith takes a tuple and matches any of its members
    supported_prefixes = (
        "Qwen/",
        "sentence-transformers/",
        "BAAI/",
        "intfloat/",
        "microsoft/",
        "nvidia/",
    )
    if model_name.startswith(supported_prefixes):
        return True

    # Last resort: a local model path
    return os.path.exists(model_name)
|
|
28
|
+
|
|
29
|
+
def estimate_cache_size(num_texts: int, embedding_dim: int = 1024) -> str:
    """Return a human-readable size estimate for a cache of ``num_texts`` rows.

    Each row is counted as 40 bytes of hash, 100 bytes of text (an assumed
    average) and ``embedding_dim * 4`` bytes of embedding. The total is
    rendered in bytes, KB, MB or GB using 1024-based units.
    """
    kib = 1024
    mib = kib * 1024
    gib = mib * 1024

    per_entry = 40 + 100 + (embedding_dim * 4)
    total = num_texts * per_entry

    if total < kib:
        return f"{total} bytes"
    if total < mib:
        return f"{total / 1024:.1f} KB"
    if total < gib:
        return f"{total / (1024 * 1024):.1f} MB"
    return f"{total / (1024 * 1024 * 1024):.1f} GB"
|
|
@@ -18,7 +18,7 @@
|
|
|
18
18
|
# • memoize(func) -> Callable - Function result caching decorator
|
|
19
19
|
# • identify(obj: Any) -> str - Generate unique object identifier
|
|
20
20
|
# • identify_uuid(obj: Any) -> str - Generate UUID-based object identifier
|
|
21
|
-
# • load_by_ext(fname: str
|
|
21
|
+
# • load_by_ext(fname: Union[str, list[str]]) -> Any - Auto-detect file format loader
|
|
22
22
|
# • dump_json_or_pickle(obj: Any, fname: str) -> None - Smart file serializer
|
|
23
23
|
# • load_json_or_pickle(fname: str) -> Any - Smart file deserializer
|
|
24
24
|
# • multi_thread(func, items, **kwargs) -> list - Parallel thread execution
|
|
@@ -5,7 +5,7 @@ import re
|
|
|
5
5
|
import sys
|
|
6
6
|
import time
|
|
7
7
|
from collections import OrderedDict
|
|
8
|
-
from typing import Annotated, Literal
|
|
8
|
+
from typing import Annotated, Literal, Union
|
|
9
9
|
|
|
10
10
|
from loguru import logger
|
|
11
11
|
|
|
@@ -166,7 +166,7 @@ def log(
|
|
|
166
166
|
*,
|
|
167
167
|
level: Literal["info", "warning", "error", "critical", "success"] = "info",
|
|
168
168
|
once: bool = False,
|
|
169
|
-
interval: float
|
|
169
|
+
interval: Union[float, None] = None,
|
|
170
170
|
) -> None:
|
|
171
171
|
"""
|
|
172
172
|
Log a message using loguru with optional `once` and `interval` control.
|
|
@@ -7,7 +7,7 @@ import pickle
|
|
|
7
7
|
import time
|
|
8
8
|
from glob import glob
|
|
9
9
|
from pathlib import Path
|
|
10
|
-
from typing import Any
|
|
10
|
+
from typing import Any, Union
|
|
11
11
|
|
|
12
12
|
from json_repair import loads as jloads
|
|
13
13
|
from pydantic import BaseModel
|
|
@@ -92,7 +92,7 @@ def load_jsonl(path):
|
|
|
92
92
|
return [json.loads(line) for line in lines]
|
|
93
93
|
|
|
94
94
|
|
|
95
|
-
def load_by_ext(fname: str
|
|
95
|
+
def load_by_ext(fname: Union[str, list[str]], do_memoize: bool = False) -> Any:
|
|
96
96
|
"""
|
|
97
97
|
Load data based on file extension.
|
|
98
98
|
"""
|
|
@@ -3,7 +3,7 @@
|
|
|
3
3
|
import copy
|
|
4
4
|
import pprint
|
|
5
5
|
import textwrap
|
|
6
|
-
from typing import Any
|
|
6
|
+
from typing import Any, Union
|
|
7
7
|
|
|
8
8
|
from tabulate import tabulate
|
|
9
9
|
|
|
@@ -24,17 +24,17 @@ def flatten_dict(d, parent_key="", sep="."):
|
|
|
24
24
|
|
|
25
25
|
def fprint(
|
|
26
26
|
input_data: Any,
|
|
27
|
-
key_ignore: list[str]
|
|
28
|
-
key_keep: list[str]
|
|
27
|
+
key_ignore: Union[list[str], None] = None,
|
|
28
|
+
key_keep: Union[list[str], None] = None,
|
|
29
29
|
max_width: int = 100,
|
|
30
30
|
indent: int = 2,
|
|
31
|
-
depth: int
|
|
31
|
+
depth: Union[int, None] = None,
|
|
32
32
|
table_format: str = "grid",
|
|
33
33
|
str_wrap_width: int = 80,
|
|
34
34
|
grep=None,
|
|
35
35
|
is_notebook=None,
|
|
36
36
|
f=print,
|
|
37
|
-
) -> None
|
|
37
|
+
) -> Union[None, str]:
|
|
38
38
|
"""
|
|
39
39
|
Pretty print structured data.
|
|
40
40
|
"""
|
|
@@ -4,7 +4,7 @@ import traceback
|
|
|
4
4
|
from collections.abc import Callable, Iterable, Iterator
|
|
5
5
|
from concurrent.futures import ProcessPoolExecutor, as_completed
|
|
6
6
|
from itertools import islice
|
|
7
|
-
from typing import Any, TypeVar, cast
|
|
7
|
+
from typing import Any, TypeVar, Union, cast
|
|
8
8
|
|
|
9
9
|
T = TypeVar("T")
|
|
10
10
|
|
|
@@ -65,12 +65,12 @@ def multi_process(
|
|
|
65
65
|
func: Callable[[Any], Any],
|
|
66
66
|
inputs: Iterable[Any],
|
|
67
67
|
*,
|
|
68
|
-
workers: int
|
|
68
|
+
workers: Union[int, None] = None,
|
|
69
69
|
batch: int = 1,
|
|
70
70
|
ordered: bool = True,
|
|
71
71
|
progress: bool = False,
|
|
72
|
-
inflight: int
|
|
73
|
-
timeout: float
|
|
72
|
+
inflight: Union[int, None] = None,
|
|
73
|
+
timeout: Union[float, None] = None,
|
|
74
74
|
stop_on_error: bool = True,
|
|
75
75
|
process_update_interval=10,
|
|
76
76
|
for_loop: bool = False,
|
|
@@ -83,7 +83,7 @@ import traceback
|
|
|
83
83
|
from collections.abc import Callable, Iterable
|
|
84
84
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
|
85
85
|
from itertools import islice
|
|
86
|
-
from typing import Any, TypeVar
|
|
86
|
+
from typing import Any, TypeVar, Union
|
|
87
87
|
|
|
88
88
|
from loguru import logger
|
|
89
89
|
|
|
@@ -125,16 +125,16 @@ def multi_thread(
|
|
|
125
125
|
func: Callable,
|
|
126
126
|
inputs: Iterable[Any],
|
|
127
127
|
*,
|
|
128
|
-
workers: int
|
|
128
|
+
workers: Union[int, None] = DEFAULT_WORKERS,
|
|
129
129
|
batch: int = 1,
|
|
130
130
|
ordered: bool = True,
|
|
131
131
|
progress: bool = True,
|
|
132
132
|
progress_update: int = 10,
|
|
133
133
|
prefetch_factor: int = 4,
|
|
134
|
-
timeout: float
|
|
134
|
+
timeout: Union[float, None] = None,
|
|
135
135
|
stop_on_error: bool = True,
|
|
136
136
|
n_proc=0,
|
|
137
|
-
store_output_pkl_file: str
|
|
137
|
+
store_output_pkl_file: Union[str, None] = None,
|
|
138
138
|
**fixed_kwargs,
|
|
139
139
|
) -> list[Any]:
|
|
140
140
|
"""
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{speedy_utils-1.1.12 → speedy_utils-1.1.14}/src/speedy_utils/scripts/openapi_client_codegen.py
RENAMED
|
File without changes
|