speedy-utils 1.1.12__tar.gz → 1.1.13__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44)
  1. {speedy_utils-1.1.12 → speedy_utils-1.1.13}/PKG-INFO +1 -1
  2. {speedy_utils-1.1.12 → speedy_utils-1.1.13}/pyproject.toml +1 -1
  3. {speedy_utils-1.1.12 → speedy_utils-1.1.13}/src/llm_utils/__init__.py +4 -1
  4. {speedy_utils-1.1.12 → speedy_utils-1.1.13}/src/llm_utils/lm/async_lm/async_lm.py +2 -1
  5. {speedy_utils-1.1.12 → speedy_utils-1.1.13}/src/llm_utils/lm/async_lm/async_lm_base.py +6 -6
  6. speedy_utils-1.1.13/src/llm_utils/vector_cache/__init__.py +25 -0
  7. speedy_utils-1.1.13/src/llm_utils/vector_cache/cli.py +200 -0
  8. speedy_utils-1.1.13/src/llm_utils/vector_cache/core.py +538 -0
  9. speedy_utils-1.1.13/src/llm_utils/vector_cache/types.py +15 -0
  10. speedy_utils-1.1.13/src/llm_utils/vector_cache/utils.py +42 -0
  11. {speedy_utils-1.1.12 → speedy_utils-1.1.13}/src/speedy_utils/common/logger.py +2 -2
  12. {speedy_utils-1.1.12 → speedy_utils-1.1.13}/src/speedy_utils/common/utils_io.py +1 -1
  13. {speedy_utils-1.1.12 → speedy_utils-1.1.13}/src/speedy_utils/multi_worker/process.py +4 -4
  14. {speedy_utils-1.1.12 → speedy_utils-1.1.13}/src/speedy_utils/multi_worker/thread.py +4 -4
  15. {speedy_utils-1.1.12 → speedy_utils-1.1.13}/README.md +0 -0
  16. {speedy_utils-1.1.12 → speedy_utils-1.1.13}/src/llm_utils/chat_format/__init__.py +0 -0
  17. {speedy_utils-1.1.12 → speedy_utils-1.1.13}/src/llm_utils/chat_format/display.py +0 -0
  18. {speedy_utils-1.1.12 → speedy_utils-1.1.13}/src/llm_utils/chat_format/transform.py +0 -0
  19. {speedy_utils-1.1.12 → speedy_utils-1.1.13}/src/llm_utils/chat_format/utils.py +0 -0
  20. {speedy_utils-1.1.12 → speedy_utils-1.1.13}/src/llm_utils/group_messages.py +0 -0
  21. {speedy_utils-1.1.12 → speedy_utils-1.1.13}/src/llm_utils/lm/__init__.py +0 -0
  22. {speedy_utils-1.1.12 → speedy_utils-1.1.13}/src/llm_utils/lm/async_lm/__init__.py +0 -0
  23. {speedy_utils-1.1.12 → speedy_utils-1.1.13}/src/llm_utils/lm/async_lm/_utils.py +0 -0
  24. {speedy_utils-1.1.12 → speedy_utils-1.1.13}/src/llm_utils/lm/async_lm/async_llm_task.py +0 -0
  25. {speedy_utils-1.1.12 → speedy_utils-1.1.13}/src/llm_utils/lm/async_lm/lm_specific.py +0 -0
  26. {speedy_utils-1.1.12 → speedy_utils-1.1.13}/src/llm_utils/lm/openai_memoize.py +0 -0
  27. {speedy_utils-1.1.12 → speedy_utils-1.1.13}/src/llm_utils/lm/utils.py +0 -0
  28. {speedy_utils-1.1.12 → speedy_utils-1.1.13}/src/llm_utils/scripts/README.md +0 -0
  29. {speedy_utils-1.1.12 → speedy_utils-1.1.13}/src/llm_utils/scripts/vllm_load_balancer.py +0 -0
  30. {speedy_utils-1.1.12 → speedy_utils-1.1.13}/src/llm_utils/scripts/vllm_serve.py +0 -0
  31. {speedy_utils-1.1.12 → speedy_utils-1.1.13}/src/speedy_utils/__init__.py +0 -0
  32. {speedy_utils-1.1.12 → speedy_utils-1.1.13}/src/speedy_utils/all.py +0 -0
  33. {speedy_utils-1.1.12 → speedy_utils-1.1.13}/src/speedy_utils/common/__init__.py +0 -0
  34. {speedy_utils-1.1.12 → speedy_utils-1.1.13}/src/speedy_utils/common/clock.py +0 -0
  35. {speedy_utils-1.1.12 → speedy_utils-1.1.13}/src/speedy_utils/common/function_decorator.py +0 -0
  36. {speedy_utils-1.1.12 → speedy_utils-1.1.13}/src/speedy_utils/common/notebook_utils.py +0 -0
  37. {speedy_utils-1.1.12 → speedy_utils-1.1.13}/src/speedy_utils/common/report_manager.py +0 -0
  38. {speedy_utils-1.1.12 → speedy_utils-1.1.13}/src/speedy_utils/common/utils_cache.py +0 -0
  39. {speedy_utils-1.1.12 → speedy_utils-1.1.13}/src/speedy_utils/common/utils_misc.py +0 -0
  40. {speedy_utils-1.1.12 → speedy_utils-1.1.13}/src/speedy_utils/common/utils_print.py +0 -0
  41. {speedy_utils-1.1.12 → speedy_utils-1.1.13}/src/speedy_utils/multi_worker/__init__.py +0 -0
  42. {speedy_utils-1.1.12 → speedy_utils-1.1.13}/src/speedy_utils/scripts/__init__.py +0 -0
  43. {speedy_utils-1.1.12 → speedy_utils-1.1.13}/src/speedy_utils/scripts/mpython.py +0 -0
  44. {speedy_utils-1.1.12 → speedy_utils-1.1.13}/src/speedy_utils/scripts/openapi_client_codegen.py +0 -0
{speedy_utils-1.1.12 → speedy_utils-1.1.13}/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.3
  Name: speedy-utils
- Version: 1.1.12
+ Version: 1.1.13
  Summary: Fast and easy-to-use package for data science
  Author: AnhVTH
  Author-email: anhvth.226@gmail.com
{speedy_utils-1.1.12 → speedy_utils-1.1.13}/pyproject.toml
@@ -1,6 +1,6 @@
  [tool.poetry]
  name = "speedy-utils"
- version = "1.1.12"
+ version = "1.1.13"
  description = "Fast and easy-to-use package for data science"
  authors = ["AnhVTH <anhvth.226@gmail.com>"]
  readme = "README.md"
{speedy_utils-1.1.12 → speedy_utils-1.1.13}/src/llm_utils/__init__.py
@@ -1,4 +1,6 @@
  from llm_utils.lm.openai_memoize import MOpenAI
+ from llm_utils.vector_cache import VectorCache
+
  from .chat_format import (
  build_chatml_input,
  display_chat_messages_as_html,
@@ -24,5 +26,6 @@ __all__ = [
  "display_chat_messages_as_html",
  "AsyncLM",
  "AsyncLLMTask",
- "MOpenAI"
+ "MOpenAI",
+ "VectorCache",
  ]
{speedy_utils-1.1.12 → speedy_utils-1.1.13}/src/llm_utils/lm/async_lm/async_lm.py
@@ -5,6 +5,7 @@ from typing import (
  Literal,
  Optional,
  Type,
+ Union,
  cast,
  )

@@ -49,7 +50,7 @@ class AsyncLM(AsyncLMBase):
  temperature: float = 0.0,
  max_tokens: int = 2_000,
  host: str = "localhost",
- port: Optional[int | str] = None,
+ port: Optional[Union[int, str]] = None,
  base_url: Optional[str] = None,
  api_key: Optional[str] = None,
  cache: bool = True,
{speedy_utils-1.1.12 → speedy_utils-1.1.13}/src/llm_utils/lm/async_lm/async_lm_base.py
@@ -41,7 +41,7 @@ class AsyncLMBase:
  self,
  *,
  host: str = "localhost",
- port: Optional[int | str] = None,
+ port: Optional[Union[int, str]] = None,
  base_url: Optional[str] = None,
  api_key: Optional[str] = None,
  cache: bool = True,
@@ -81,8 +81,8 @@ class AsyncLMBase:
  async def __call__(  # type: ignore
  self,
  *,
- prompt: str | None = ...,
- messages: RawMsgs | None = ...,
+ prompt: Optional[str] = ...,
+ messages: Optional[RawMsgs] = ...,
  response_format: type[str] = str,
  return_openai_response: bool = ...,
  **kwargs: Any,
@@ -92,8 +92,8 @@ class AsyncLMBase:
  async def __call__(
  self,
  *,
- prompt: str | None = ...,
- messages: RawMsgs | None = ...,
+ prompt: Optional[str] = ...,
+ messages: Optional[RawMsgs] = ...,
  response_format: Type[TModel],
  return_openai_response: bool = ...,
  **kwargs: Any,
@@ -137,7 +137,7 @@ class AsyncLMBase:
  @staticmethod
  def _parse_output(
  raw_response: Any, response_format: Union[type[str], Type[BaseModel]]
- ) -> str | BaseModel:
+ ) -> Union[str, BaseModel]:
  if hasattr(raw_response, "model_dump"):
  raw_response = raw_response.model_dump()
speedy_utils-1.1.13/src/llm_utils/vector_cache/__init__.py (new file)
@@ -0,0 +1,25 @@
+ """
+ Efficient embedding caching system using vLLM for offline embeddings.
+
+ This package provides a fast, SQLite-backed caching layer for text embeddings,
+ supporting both OpenAI API and local models via vLLM.
+
+ Classes:
+     VectorCache: Main class for embedding computation and caching
+
+ Example:
+     # Using local model
+     cache = VectorCache("Qwen/Qwen3-Embedding-0.6B")
+     embeddings = cache.embeds(["Hello world", "How are you?"])
+
+     # Using OpenAI API
+     cache = VectorCache("https://api.openai.com/v1")
+     embeddings = cache.embeds(["Hello world", "How are you?"])
+ """
+
+ from .core import VectorCache
+ from .utils import get_default_cache_path, validate_model_name, estimate_cache_size
+
+ __version__ = "0.1.0"
+ __author__ = "AnhVTH <anhvth.226@gmail.com>"
+ __all__ = ["VectorCache", "get_default_cache_path", "validate_model_name", "estimate_cache_size"]
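Combined with the llm_utils/__init__.py hunk above, VectorCache is importable both from llm_utils and from llm_utils.vector_cache. A minimal usage sketch based on the exports listed here; the model name and texts are illustrative, and the embeds() call loads (or downloads) the model:

from llm_utils.vector_cache import VectorCache, estimate_cache_size, validate_model_name

assert validate_model_name("Qwen/Qwen3-Embedding-0.6B")   # matches the known "Qwen/" prefix
print(estimate_cache_size(1_000))                         # rough on-disk estimate, about "4.0 MB"
cache = VectorCache("Qwen/Qwen3-Embedding-0.6B")          # local model, backend auto-detected as vLLM
vectors = cache.embeds(["Hello world", "How are you?"])   # numpy array, one row per input text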
speedy_utils-1.1.13/src/llm_utils/vector_cache/cli.py (new file)
@@ -0,0 +1,200 @@
+ #!/usr/bin/env python3
+ """Command-line interface for embed_cache package."""
+
+ import argparse
+ import json
+ import sys
+ from pathlib import Path
+ from typing import List
+
+ from llm_utils.vector_cache import VectorCache, estimate_cache_size, validate_model_name
+
+
+ def main():
+     """Main CLI entry point."""
+     parser = argparse.ArgumentParser(
+         description="Command-line interface for embed_cache package"
+     )
+
+     subparsers = parser.add_subparsers(dest="command", help="Available commands")
+
+     # Embed command
+     embed_parser = subparsers.add_parser("embed", help="Generate embeddings for texts")
+     embed_parser.add_argument("model", help="Model name or API URL")
+     embed_parser.add_argument("--texts", nargs="+", help="Texts to embed")
+     embed_parser.add_argument("--file", help="File containing texts (one per line)")
+     embed_parser.add_argument("--output", help="Output file for embeddings (JSON)")
+     embed_parser.add_argument(
+         "--cache-db", default="embed_cache.sqlite", help="Cache database path"
+     )
+     embed_parser.add_argument(
+         "--backend",
+         choices=["vllm", "transformers", "openai", "auto"],
+         default="auto",
+         help="Backend to use",
+     )
+     embed_parser.add_argument(
+         "--gpu-memory",
+         type=float,
+         default=0.5,
+         help="GPU memory utilization for vLLM (0.0-1.0)",
+     )
+     embed_parser.add_argument(
+         "--batch-size", type=int, default=32, help="Batch size for transformers"
+     )
+     embed_parser.add_argument(
+         "--verbose", action="store_true", help="Enable verbose output"
+     )
+
+     # Cache stats command
+     stats_parser = subparsers.add_parser("stats", help="Show cache statistics")
+     stats_parser.add_argument(
+         "--cache-db", default="embed_cache.sqlite", help="Cache database path"
+     )
+
+     # Clear cache command
+     clear_parser = subparsers.add_parser("clear", help="Clear cache")
+     clear_parser.add_argument(
+         "--cache-db", default="embed_cache.sqlite", help="Cache database path"
+     )
+     clear_parser.add_argument(
+         "--confirm", action="store_true", help="Skip confirmation prompt"
+     )
+
+     # Validate model command
+     validate_parser = subparsers.add_parser("validate", help="Validate model name")
+     validate_parser.add_argument("model", help="Model name to validate")
+
+     # Estimate command
+     estimate_parser = subparsers.add_parser("estimate", help="Estimate cache size")
+     estimate_parser.add_argument("num_texts", type=int, help="Number of texts")
+     estimate_parser.add_argument(
+         "--embed-dim", type=int, default=1024, help="Embedding dimension"
+     )
+
+     args = parser.parse_args()
+
+     if not args.command:
+         parser.print_help()
+         return
+
+     try:
+         if args.command == "embed":
+             handle_embed(args)
+         elif args.command == "stats":
+             handle_stats(args)
+         elif args.command == "clear":
+             handle_clear(args)
+         elif args.command == "validate":
+             handle_validate(args)
+         elif args.command == "estimate":
+             handle_estimate(args)
+     except Exception as e:
+         print(f"Error: {e}", file=sys.stderr)
+         sys.exit(1)
+
+
+ def handle_embed(args):
+     """Handle embed command."""
+     # Get texts
+     texts = []
+     if args.texts:
+         texts.extend(args.texts)
+
+     if args.file:
+         file_path = Path(args.file)
+         if not file_path.exists():
+             raise FileNotFoundError(f"File not found: {args.file}")
+
+         with open(file_path, "r", encoding="utf-8") as f:
+             texts.extend([line.strip() for line in f if line.strip()])
+
+     if not texts:
+         raise ValueError("No texts provided. Use --texts or --file")
+
+     print(f"Embedding {len(texts)} texts using model: {args.model}")
+
+     # Initialize cache and get embeddings
+     cache = VectorCache(args.model, db_path=args.cache_db)
+     embeddings = cache.embeds(texts)
+
+     print(f"Generated embeddings with shape: {embeddings.shape}")
+
+     # Output results
+     if args.output:
+         output_data = {
+             "texts": texts,
+             "embeddings": embeddings.tolist(),
+             "shape": list(embeddings.shape),
+             "model": args.model,
+         }
+
+         with open(args.output, "w", encoding="utf-8") as f:
+             json.dump(output_data, f, indent=2)
+
+         print(f"Results saved to: {args.output}")
+     else:
+         print(f"Embeddings shape: {embeddings.shape}")
+         print(f"Sample embedding (first 5 dims): {embeddings[0][:5].tolist()}")
+
+
+ def handle_stats(args):
+     """Handle stats command."""
+     cache_path = Path(args.cache_db)
+     if not cache_path.exists():
+         print(f"Cache database not found: {args.cache_db}")
+         return
+
+     cache = VectorCache("dummy", db_path=args.cache_db)
+     stats = cache.get_cache_stats()
+
+     print("Cache Statistics:")
+     print(f" Database: {args.cache_db}")
+     print(f" Total cached embeddings: {stats['total_cached']}")
+     print(f" Database size: {cache_path.stat().st_size / (1024 * 1024):.2f} MB")
+
+
+ def handle_clear(args):
+     """Handle clear command."""
+     cache_path = Path(args.cache_db)
+     if not cache_path.exists():
+         print(f"Cache database not found: {args.cache_db}")
+         return
+
+     if not args.confirm:
+         response = input(
+             f"Are you sure you want to clear cache at {args.cache_db}? [y/N]: "
+         )
+         if response.lower() != "y":
+             print("Cancelled.")
+             return
+
+     cache = VectorCache("dummy", db_path=args.cache_db)
+     stats_before = cache.get_cache_stats()
+     cache.clear_cache()
+
+     print(
+         f"Cleared {stats_before['total_cached']} cached embeddings from {args.cache_db}"
+     )
+
+
+ def handle_validate(args):
+     """Handle validate command."""
+     is_valid = validate_model_name(args.model)
+     if is_valid:
+         print(f"✓ Valid model: {args.model}")
+     else:
+         print(f"✗ Invalid model: {args.model}")
+         sys.exit(1)
+
+
+ def handle_estimate(args):
+     """Handle estimate command."""
+     size_estimate = estimate_cache_size(args.num_texts, args.embed_dim)
+     print(
+         f"Estimated cache size for {args.num_texts} texts ({args.embed_dim}D embeddings): {size_estimate}"
+     )
+
+
+ if __name__ == "__main__":
+     main()
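No console-script entry point appears in this diff, so the CLI is exercised either by running the module or by calling the handlers directly; a small sketch with illustrative values:

import argparse
from llm_utils.vector_cache.cli import handle_estimate

# Equivalent to `estimate 50000 --embed-dim 1024`, bypassing parse_args().
handle_estimate(argparse.Namespace(num_texts=50_000, embed_dim=1024))

# The full parser can also be driven as a module, for example:
#   python -m llm_utils.vector_cache.cli embed Qwen/Qwen3-Embedding-0.6B --texts "Hello world"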
speedy_utils-1.1.13/src/llm_utils/vector_cache/core.py (new file)
@@ -0,0 +1,538 @@
+ from __future__ import annotations
+
+ import hashlib
+ import os
+ import sqlite3
+ from pathlib import Path
+ from typing import Any, Dict, Literal, Optional, cast
+
+ import numpy as np
+
+
+ class VectorCache:
+     """
+     A caching layer for text embeddings with support for multiple backends.
+
+     Examples:
+         # OpenAI API
+         from llm_utils import VectorCache
+         cache = VectorCache("https://api.openai.com/v1", api_key="your-key")
+         embeddings = cache.embeds(["Hello world", "How are you?"])
+
+         # Custom OpenAI-compatible server (auto-detects model)
+         cache = VectorCache("http://localhost:8000/v1", api_key="abc")
+
+         # Transformers (Sentence Transformers)
+         cache = VectorCache("sentence-transformers/all-MiniLM-L6-v2")
+
+         # vLLM (local model)
+         cache = VectorCache("/path/to/model")
+
+         # Explicit backend specification
+         cache = VectorCache("model-name", backend="transformers")
+
+         # Lazy loading (default: True) - load model only when needed
+         cache = VectorCache("model-name", lazy=True)
+
+         # Eager loading - load model immediately
+         cache = VectorCache("model-name", lazy=False)
+     """
+     def __init__(
+         self,
+         url_or_model: str,
+         backend: Optional[Literal["vllm", "transformers", "openai"]] = None,
+         embed_size: Optional[int] = None,
+         db_path: Optional[str] = None,
+         # OpenAI API parameters
+         api_key: Optional[str] = "abc",
+         model_name: Optional[str] = None,
+         # vLLM parameters
+         vllm_gpu_memory_utilization: float = 0.5,
+         vllm_tensor_parallel_size: int = 1,
+         vllm_dtype: str = "auto",
+         vllm_trust_remote_code: bool = False,
+         vllm_max_model_len: Optional[int] = None,
+         # Transformers parameters
+         transformers_device: str = "auto",
+         transformers_batch_size: int = 32,
+         transformers_normalize_embeddings: bool = True,
+         transformers_trust_remote_code: bool = False,
+         # SQLite parameters
+         sqlite_chunk_size: int = 999,
+         sqlite_cache_size: int = 10000,
+         sqlite_mmap_size: int = 268435456,
+         # Other parameters
+         verbose: bool = True,
+         lazy: bool = True,
+     ) -> None:
+         self.url_or_model = url_or_model
+         self.embed_size = embed_size
+         self.verbose = verbose
+         self.lazy = lazy
+
+         self.backend = self._determine_backend(backend)
+         if self.verbose and backend is None:
+             print(f"Auto-detected backend: {self.backend}")
+
+         # Store all configuration parameters
+         self.config = {
+             # OpenAI
+             "api_key": api_key or os.getenv("OPENAI_API_KEY"),
+             "model_name": model_name,
+             # vLLM
+             "vllm_gpu_memory_utilization": vllm_gpu_memory_utilization,
+             "vllm_tensor_parallel_size": vllm_tensor_parallel_size,
+             "vllm_dtype": vllm_dtype,
+             "vllm_trust_remote_code": vllm_trust_remote_code,
+             "vllm_max_model_len": vllm_max_model_len,
+             # Transformers
+             "transformers_device": transformers_device,
+             "transformers_batch_size": transformers_batch_size,
+             "transformers_normalize_embeddings": transformers_normalize_embeddings,
+             "transformers_trust_remote_code": transformers_trust_remote_code,
+             # SQLite
+             "sqlite_chunk_size": sqlite_chunk_size,
+             "sqlite_cache_size": sqlite_cache_size,
+             "sqlite_mmap_size": sqlite_mmap_size,
+         }
+
+         # Auto-detect model_name for OpenAI if using custom URL and default model
+         if (self.backend == "openai" and
+                 model_name == "text-embedding-3-small" and
+                 self.url_or_model != "https://api.openai.com/v1"):
+             if self.verbose:
+                 print(f"Attempting to auto-detect model from {self.url_or_model}...")
+             try:
+                 import openai
+                 client = openai.OpenAI(
+                     base_url=self.url_or_model,
+                     api_key=self.config["api_key"]
+                 )
+                 models = client.models.list()
+                 if models.data:
+                     detected_model = models.data[0].id
+                     self.config["model_name"] = detected_model
+                     model_name = detected_model  # Update for db_path computation
+                     if self.verbose:
+                         print(f"Auto-detected model: {detected_model}")
+                 else:
+                     if self.verbose:
+                         print("No models found, using default model")
+             except Exception as e:
+                 if self.verbose:
+                     print(f"Model auto-detection failed: {e}, using default model")
+                 # Fallback to default if auto-detection fails
+                 pass
+
+         # Set default db_path if not provided
+         if db_path is None:
+             if self.backend == "openai":
+                 model_id = self.config["model_name"] or "openai-default"
+             else:
+                 model_id = self.url_or_model
+             safe_name = hashlib.sha1(model_id.encode("utf-8")).hexdigest()[:16]
+             self.db_path = Path.home() / ".cache" / "embed" / f"{self.backend}_{safe_name}.sqlite"
+         else:
+             self.db_path = Path(db_path)
+
+         # Ensure the directory exists
+         self.db_path.parent.mkdir(parents=True, exist_ok=True)
+
+         self.conn = sqlite3.connect(self.db_path)
+         self._optimize_connection()
+         self._ensure_schema()
+         self._model = None  # Lazy loading
+         self._client = None  # For OpenAI client
+
+         # Load model/client if not lazy
+         if not self.lazy:
+             if self.backend == "openai":
+                 self._load_openai_client()
+             elif self.backend in ["vllm", "transformers"]:
+                 self._load_model()
+
+     def _determine_backend(self, backend: Optional[Literal["vllm", "transformers", "openai"]]) -> str:
+         """Determine the appropriate backend based on url_or_model and user preference."""
+         if backend is not None:
+             valid_backends = ["vllm", "transformers", "openai"]
+             if backend not in valid_backends:
+                 raise ValueError(f"Invalid backend '{backend}'. Must be one of: {valid_backends}")
+             return backend
+
+         if self.url_or_model.startswith("http"):
+             return "openai"
+
+         # Default to vllm for local models
+         return "vllm"
+
+     def _optimize_connection(self) -> None:
+         """Optimize SQLite connection for bulk operations."""
+         # Performance optimizations for bulk operations
+         self.conn.execute(
+             "PRAGMA journal_mode=WAL"
+         )  # Write-Ahead Logging for better concurrency
+         self.conn.execute("PRAGMA synchronous=NORMAL")  # Faster writes, still safe
+         self.conn.execute(f"PRAGMA cache_size={self.config['sqlite_cache_size']}")  # Configurable cache
+         self.conn.execute("PRAGMA temp_store=MEMORY")  # Use memory for temp storage
+         self.conn.execute(f"PRAGMA mmap_size={self.config['sqlite_mmap_size']}")  # Configurable memory mapping
+
+     def _ensure_schema(self) -> None:
+         self.conn.execute("""
+             CREATE TABLE IF NOT EXISTS cache (
+                 hash TEXT PRIMARY KEY,
+                 text TEXT,
+                 embedding BLOB
+             )
+         """)
+         # Add index for faster lookups if it doesn't exist
+         self.conn.execute("""
+             CREATE INDEX IF NOT EXISTS idx_cache_hash ON cache(hash)
+         """)
+         self.conn.commit()
+
+     def _load_openai_client(self) -> None:
+         """Load OpenAI client."""
+         import openai
+         self._client = openai.OpenAI(
+             base_url=self.url_or_model,
+             api_key=self.config["api_key"]
+         )
+
+     def _load_model(self) -> None:
+         """Load the model for vLLM or Transformers."""
+         if self.backend == "vllm":
+             from vllm import LLM
+
+             gpu_memory_utilization = cast(float, self.config["vllm_gpu_memory_utilization"])
+             tensor_parallel_size = cast(int, self.config["vllm_tensor_parallel_size"])
+             dtype = cast(str, self.config["vllm_dtype"])
+             trust_remote_code = cast(bool, self.config["vllm_trust_remote_code"])
+             max_model_len = cast(Optional[int], self.config["vllm_max_model_len"])
+
+             vllm_kwargs = {
+                 "model": self.url_or_model,
+                 "task": "embed",
+                 "gpu_memory_utilization": gpu_memory_utilization,
+                 "tensor_parallel_size": tensor_parallel_size,
+                 "dtype": dtype,
+                 "trust_remote_code": trust_remote_code,
+             }
+
+             if max_model_len is not None:
+                 vllm_kwargs["max_model_len"] = max_model_len
+
+             try:
+                 self._model = LLM(**vllm_kwargs)
+             except (ValueError, AssertionError, RuntimeError) as e:
+                 error_msg = str(e).lower()
+                 if ("kv cache" in error_msg and "gpu_memory_utilization" in error_msg) or \
+                    ("memory" in error_msg and ("gpu" in error_msg or "insufficient" in error_msg)) or \
+                    ("free memory" in error_msg and "initial" in error_msg) or \
+                    ("engine core initialization failed" in error_msg):
+                     raise ValueError(
+                         f"Insufficient GPU memory for vLLM model initialization. "
+                         f"Current vllm_gpu_memory_utilization ({gpu_memory_utilization}) may be too low. "
+                         f"Try one of the following:\n"
+                         f"1. Increase vllm_gpu_memory_utilization (e.g., 0.5, 0.8, or 0.9)\n"
+                         f"2. Decrease vllm_max_model_len (e.g., 4096, 8192)\n"
+                         f"3. Use a smaller model\n"
+                         f"4. Ensure no other processes are using GPU memory during initialization\n"
+                         f"Original error: {e}"
+                     ) from e
+                 else:
+                     raise
+         elif self.backend == "transformers":
+             from transformers import AutoTokenizer, AutoModel
+             import torch
+
+             device = self.config["transformers_device"]
+             # Handle "auto" device selection - default to CPU for transformers to avoid memory conflicts
+             if device == "auto":
+                 device = "cpu"  # Default to CPU to avoid GPU memory conflicts with vLLM
+
+             tokenizer = AutoTokenizer.from_pretrained(self.url_or_model, padding_side='left', trust_remote_code=self.config["transformers_trust_remote_code"])
+             model = AutoModel.from_pretrained(self.url_or_model, trust_remote_code=self.config["transformers_trust_remote_code"])
+
+             # Move model to device
+             model.to(device)
+             model.eval()
+
+             self._model = {"tokenizer": tokenizer, "model": model, "device": device}
+
+     def _get_embeddings(self, texts: list[str]) -> list[list[float]]:
+         """Get embeddings using the configured backend."""
+         if self.backend == "openai":
+             return self._get_openai_embeddings(texts)
+         elif self.backend == "vllm":
+             return self._get_vllm_embeddings(texts)
+         elif self.backend == "transformers":
+             return self._get_transformers_embeddings(texts)
+         else:
+             raise ValueError(f"Unsupported backend: {self.backend}")
+
+     def _get_openai_embeddings(self, texts: list[str]) -> list[list[float]]:
+         """Get embeddings using OpenAI API."""
+         # Assert valid model_name for OpenAI backend
+         model_name = self.config["model_name"]
+         assert model_name is not None and model_name.strip(), f"Invalid model_name for OpenAI backend: {model_name}. Model name must be provided and non-empty."
+
+         if self._client is None:
+             self._load_openai_client()
+
+         response = self._client.embeddings.create(  # type: ignore
+             model=model_name,
+             input=texts
+         )
+         embeddings = [item.embedding for item in response.data]
+         return embeddings
+
+     def _get_vllm_embeddings(self, texts: list[str]) -> list[list[float]]:
+         """Get embeddings using vLLM."""
+         if self._model is None:
+             self._load_model()
+
+         outputs = self._model.embed(texts)  # type: ignore
+         embeddings = [o.outputs.embedding for o in outputs]
+         return embeddings
+
+     def _get_transformers_embeddings(self, texts: list[str]) -> list[list[float]]:
+         """Get embeddings using transformers directly."""
+         if self._model is None:
+             self._load_model()
+
+         if not isinstance(self._model, dict):
+             raise ValueError("Model not loaded properly for transformers backend")
+
+         tokenizer = self._model["tokenizer"]
+         model = self._model["model"]
+         device = self._model["device"]
+
+         normalize_embeddings = cast(bool, self.config["transformers_normalize_embeddings"])
+
+         # For now, use a default max_length
+         max_length = 8192
+
+         # Tokenize
+         batch_dict = tokenizer(
+             texts,
+             padding=True,
+             truncation=True,
+             max_length=max_length,
+             return_tensors="pt",
+         )
+
+         # Move to device
+         batch_dict = {k: v.to(device) for k, v in batch_dict.items()}
+
+         # Run model
+         import torch
+         with torch.no_grad():
+             outputs = model(**batch_dict)
+
+         # Apply last token pooling
+         embeddings = self._last_token_pool(outputs.last_hidden_state, batch_dict['attention_mask'])
+
+         # Normalize if needed
+         if normalize_embeddings:
+             import torch.nn.functional as F
+             embeddings = F.normalize(embeddings, p=2, dim=1)
+
+         return embeddings.cpu().numpy().tolist()
+
+     def _last_token_pool(self, last_hidden_states, attention_mask):
+         """Apply last token pooling to get embeddings."""
+         import torch
+         left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
+         if left_padding:
+             return last_hidden_states[:, -1]
+         else:
+             sequence_lengths = attention_mask.sum(dim=1) - 1
+             batch_size = last_hidden_states.shape[0]
+             return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths]
+
+     def _hash_text(self, text: str) -> str:
+         return hashlib.sha1(text.encode("utf-8")).hexdigest()
+
+     def embeds(self, texts: list[str], cache: bool = True) -> np.ndarray:
+         """
+         Return embeddings for all texts.
+
+         If cache=True, compute and cache missing embeddings.
+         If cache=False, force recompute all embeddings and update cache.
+
+         This method processes lookups and embedding generation in chunks to
+         handle very large input lists. A tqdm progress bar is shown while
+         computing missing embeddings.
+         """
+         if not texts:
+             return np.empty((0, 0), dtype=np.float32)
+
+         hashes = [self._hash_text(t) for t in texts]
+
+         # Helper to yield chunks
+         def _chunks(lst: list[str], n: int) -> list[list[str]]:
+             return [lst[i : i + n] for i in range(0, len(lst), n)]
+
+         # Fetch known embeddings in bulk with optimized chunk size
+         hit_map: dict[str, np.ndarray] = {}
+         chunk_size = self.config["sqlite_chunk_size"]
+
+         # Use bulk lookup with optimized query
+         hash_chunks = _chunks(hashes, chunk_size)
+         for chunk in hash_chunks:
+             placeholders = ",".join("?" * len(chunk))
+             rows = self.conn.execute(
+                 f"SELECT hash, embedding FROM cache WHERE hash IN ({placeholders})",
+                 chunk,
+             ).fetchall()
+             for h, e in rows:
+                 hit_map[h] = np.frombuffer(e, dtype=np.float32)
+
+         # Determine which texts are missing
+         if cache:
+             missing_items: list[tuple[str, str]] = [
+                 (t, h) for t, h in zip(texts, hashes) if h not in hit_map
+             ]
+         else:
+             missing_items: list[tuple[str, str]] = [
+                 (t, h) for t, h in zip(texts, hashes)
+             ]
+
+         if missing_items:
+             if self.verbose:
+                 print(f"Computing embeddings for {len(missing_items)} missing texts...")
+             missing_texts = [t for t, _ in missing_items]
+             embeds = self._get_embeddings(missing_texts)
+
+             # Prepare batch data for bulk insert
+             bulk_insert_data: list[tuple[str, str, bytes]] = []
+             for (text, h), vec in zip(missing_items, embeds):
+                 arr = np.asarray(vec, dtype=np.float32)
+                 bulk_insert_data.append((h, text, arr.tobytes()))
+                 hit_map[h] = arr
+
+             self._bulk_insert(bulk_insert_data)
+
+         # Return embeddings in the original order
+         return np.vstack([hit_map[h] for h in hashes])
+
+     def __call__(self, texts: list[str], cache: bool = True) -> np.ndarray:
+         return self.embeds(texts, cache)
+
+     def _bulk_insert(self, data: list[tuple[str, str, bytes]]) -> None:
+         """Perform bulk insert of embedding data."""
+         if not data:
+             return
+
+         self.conn.executemany(
+             "INSERT OR REPLACE INTO cache (hash, text, embedding) VALUES (?, ?, ?)",
+             data,
+         )
+         self.conn.commit()
+
+     def precompute_embeddings(self, texts: list[str]) -> None:
+         """
+         Precompute embeddings for a large list of texts efficiently.
+         This is optimized for bulk operations when you know all texts upfront.
+         """
+         if not texts:
+             return
+
+         # Remove duplicates while preserving order
+         unique_texts = list(dict.fromkeys(texts))
+         if self.verbose:
+             print(f"Precomputing embeddings for {len(unique_texts)} unique texts...")
+
+         # Check which ones are already cached
+         hashes = [self._hash_text(t) for t in unique_texts]
+         existing_hashes = set()
+
+         # Bulk check for existing embeddings
+         chunk_size = self.config["sqlite_chunk_size"]
+         for i in range(0, len(hashes), chunk_size):
+             chunk = hashes[i : i + chunk_size]
+             placeholders = ",".join("?" * len(chunk))
+             rows = self.conn.execute(
+                 f"SELECT hash FROM cache WHERE hash IN ({placeholders})",
+                 chunk,
+             ).fetchall()
+             existing_hashes.update(h[0] for h in rows)
+
+         # Find missing texts
+         missing_items = [
+             (t, h) for t, h in zip(unique_texts, hashes) if h not in existing_hashes
+         ]
+
+         if not missing_items:
+             if self.verbose:
+                 print("All texts already cached!")
+             return
+
+         if self.verbose:
+             print(f"Computing {len(missing_items)} missing embeddings...")
+         missing_texts = [t for t, _ in missing_items]
+         embeds = self._get_embeddings(missing_texts)
+
+         # Prepare batch data for bulk insert
+         bulk_insert_data: list[tuple[str, str, bytes]] = []
+         for (text, h), vec in zip(missing_items, embeds):
+             arr = np.asarray(vec, dtype=np.float32)
+             bulk_insert_data.append((h, text, arr.tobytes()))
+
+         self._bulk_insert(bulk_insert_data)
+         if self.verbose:
+             print(f"Successfully cached {len(missing_items)} new embeddings!")
+
+     def get_cache_stats(self) -> dict[str, int]:
+         """Get statistics about the cache."""
+         cursor = self.conn.execute("SELECT COUNT(*) FROM cache")
+         count = cursor.fetchone()[0]
+         return {"total_cached": count}
+
+     def clear_cache(self) -> None:
+         """Clear all cached embeddings."""
+         self.conn.execute("DELETE FROM cache")
+         self.conn.commit()
+
+     def get_config(self) -> Dict[str, Any]:
+         """Get current configuration."""
+         return {
+             "url_or_model": self.url_or_model,
+             "backend": self.backend,
+             "embed_size": self.embed_size,
+             "db_path": str(self.db_path),
+             "verbose": self.verbose,
+             "lazy": self.lazy,
+             **self.config
+         }
+
+     def update_config(self, **kwargs) -> None:
+         """Update configuration parameters."""
+         for key, value in kwargs.items():
+             if key in self.config:
+                 self.config[key] = value
+             elif key == "verbose":
+                 self.verbose = value
+             elif key == "lazy":
+                 self.lazy = value
+             else:
+                 raise ValueError(f"Unknown configuration parameter: {key}")
+
+         # Reset model if backend-specific parameters changed
+         backend_params = {
+             "vllm": ["vllm_gpu_memory_utilization", "vllm_tensor_parallel_size", "vllm_dtype",
+                      "vllm_trust_remote_code", "vllm_max_model_len"],
+             "transformers": ["transformers_device", "transformers_batch_size",
+                              "transformers_normalize_embeddings", "transformers_trust_remote_code"],
+             "openai": ["api_key", "model_name"]
+         }
+
+         if any(param in kwargs for param in backend_params.get(self.backend, [])):
+             self._model = None  # Force reload on next use
+             if self.backend == "openai":
+                 self._client = None
+
+     def __del__(self) -> None:
+         """Clean up database connection."""
+         if hasattr(self, "conn"):
+             self.conn.close()
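To illustrate the read-through behavior above: the first embeds() call computes and stores vectors, the second is served from SQLite. The endpoint, API key, and model name below are placeholders for any OpenAI-compatible server:

from llm_utils.vector_cache import VectorCache

# An "http..." prefix makes _determine_backend pick the OpenAI backend.
cache = VectorCache("http://localhost:8000/v1", api_key="abc", model_name="my-embedding-model")

texts = ["Hello world", "How are you?"]
first = cache.embeds(texts)     # cache misses: computed via the API, then bulk-inserted
second = cache.embeds(texts)    # cache hits: read back from SQLite, no API call
assert (first == second).all()

print(cache.get_cache_stats())          # e.g. {"total_cached": 2} on a fresh cache
print(cache.get_config()["db_path"])    # defaults to ~/.cache/embed/openai_<sha1 prefix>.sqlite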
speedy_utils-1.1.13/src/llm_utils/vector_cache/types.py (new file)
@@ -0,0 +1,15 @@
+ """Type definitions for the embed_cache package."""
+
+ from typing import List, Dict, Any, Union, Optional, Tuple
+ import numpy as np
+ from numpy.typing import NDArray
+
+ # Type aliases
+ TextList = List[str]
+ EmbeddingArray = NDArray[np.float32]
+ EmbeddingList = List[List[float]]
+ CacheStats = Dict[str, int]
+ ModelIdentifier = str  # Either URL or model name/path
+
+ # For backwards compatibility
+ Embeddings = Union[EmbeddingArray, EmbeddingList]
speedy_utils-1.1.13/src/llm_utils/vector_cache/utils.py (new file)
@@ -0,0 +1,42 @@
+ """Utility functions for the embed_cache package."""
+
+ import os
+ from typing import Optional
+
+ def get_default_cache_path() -> str:
+     """Get the default cache path based on environment."""
+     cache_dir = os.getenv("EMBED_CACHE_DIR", ".")
+     return os.path.join(cache_dir, "embed_cache.sqlite")
+
+ def validate_model_name(model_name: str) -> bool:
+     """Validate if a model name is supported."""
+     # Check if it's a URL
+     if model_name.startswith("http"):
+         return True
+
+     # Check if it's a valid model path/name
+     supported_prefixes = [
+         "Qwen/",
+         "sentence-transformers/",
+         "BAAI/",
+         "intfloat/",
+         "microsoft/",
+         "nvidia/",
+     ]
+
+     return any(model_name.startswith(prefix) for prefix in supported_prefixes) or os.path.exists(model_name)
+
+ def estimate_cache_size(num_texts: int, embedding_dim: int = 1024) -> str:
+     """Estimate cache size for given number of texts."""
+     # Rough estimate: hash (40 bytes) + text (avg 100 bytes) + embedding (embedding_dim * 4 bytes)
+     bytes_per_entry = 40 + 100 + (embedding_dim * 4)
+     total_bytes = num_texts * bytes_per_entry
+
+     if total_bytes < 1024:
+         return f"{total_bytes} bytes"
+     elif total_bytes < 1024 * 1024:
+         return f"{total_bytes / 1024:.1f} KB"
+     elif total_bytes < 1024 * 1024 * 1024:
+         return f"{total_bytes / (1024 * 1024):.1f} MB"
+     else:
+         return f"{total_bytes / (1024 * 1024 * 1024):.1f} GB"
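The estimate is plain arithmetic over a fixed per-row overhead; for example, at the default 1024 dimensions, 10,000 texts come to roughly 40 MB:

# Worked example of the formula above (40-byte hash + 100-byte average text + 4 bytes per dimension).
bytes_per_entry = 40 + 100 + 1024 * 4      # 4,236 bytes per cached row
total_bytes = 10_000 * bytes_per_entry     # 42,360,000 bytes
print(total_bytes / (1024 * 1024))         # ~40.4, so estimate_cache_size(10_000) returns "40.4 MB"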
{speedy_utils-1.1.12 → speedy_utils-1.1.13}/src/speedy_utils/common/logger.py
@@ -5,7 +5,7 @@ import re
  import sys
  import time
  from collections import OrderedDict
- from typing import Annotated, Literal
+ from typing import Annotated, Literal, Union

  from loguru import logger

@@ -166,7 +166,7 @@ def log(
  *,
  level: Literal["info", "warning", "error", "critical", "success"] = "info",
  once: bool = False,
- interval: float | None = None,
+ interval: Union[float, None] = None,
  ) -> None:
  """
  Log a message using loguru with optional `once` and `interval` control.
{speedy_utils-1.1.12 → speedy_utils-1.1.13}/src/speedy_utils/common/utils_io.py
@@ -7,7 +7,7 @@ import pickle
  import time
  from glob import glob
  from pathlib import Path
- from typing import Any
+ from typing import Any, Union

  from json_repair import loads as jloads
  from pydantic import BaseModel
{speedy_utils-1.1.12 → speedy_utils-1.1.13}/src/speedy_utils/multi_worker/process.py
@@ -4,7 +4,7 @@ import traceback
  from collections.abc import Callable, Iterable, Iterator
  from concurrent.futures import ProcessPoolExecutor, as_completed
  from itertools import islice
- from typing import Any, TypeVar, cast
+ from typing import Any, TypeVar, Union, cast

  T = TypeVar("T")

@@ -65,12 +65,12 @@ def multi_process(
  func: Callable[[Any], Any],
  inputs: Iterable[Any],
  *,
- workers: int | None = None,
+ workers: Union[int, None] = None,
  batch: int = 1,
  ordered: bool = True,
  progress: bool = False,
- inflight: int | None = None,
- timeout: float | None = None,
+ inflight: Union[int, None] = None,
+ timeout: Union[float, None] = None,
  stop_on_error: bool = True,
  process_update_interval=10,
  for_loop: bool = False,
{speedy_utils-1.1.12 → speedy_utils-1.1.13}/src/speedy_utils/multi_worker/thread.py
@@ -83,7 +83,7 @@ import traceback
  from collections.abc import Callable, Iterable
  from concurrent.futures import ThreadPoolExecutor, as_completed
  from itertools import islice
- from typing import Any, TypeVar
+ from typing import Any, TypeVar, Union

  from loguru import logger

@@ -125,16 +125,16 @@ def multi_thread(
  func: Callable,
  inputs: Iterable[Any],
  *,
- workers: int | None = DEFAULT_WORKERS,
+ workers: Union[int, None] = DEFAULT_WORKERS,
  batch: int = 1,
  ordered: bool = True,
  progress: bool = True,
  progress_update: int = 10,
  prefetch_factor: int = 4,
- timeout: float | None = None,
+ timeout: Union[float, None] = None,
  stop_on_error: bool = True,
  n_proc=0,
- store_output_pkl_file: str | None = None,
+ store_output_pkl_file: Union[str, None] = None,
  **fixed_kwargs,
  ) -> list[Any]:
  """
The remaining 30 files (items 15-44 above) are unchanged between the two versions.