speedy-utils 1.1.26__py3-none-any.whl → 1.1.28__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (54)
  1. llm_utils/__init__.py +16 -4
  2. llm_utils/chat_format/__init__.py +10 -10
  3. llm_utils/chat_format/display.py +33 -21
  4. llm_utils/chat_format/transform.py +17 -19
  5. llm_utils/chat_format/utils.py +6 -4
  6. llm_utils/group_messages.py +17 -14
  7. llm_utils/lm/__init__.py +6 -5
  8. llm_utils/lm/async_lm/__init__.py +1 -0
  9. llm_utils/lm/async_lm/_utils.py +10 -9
  10. llm_utils/lm/async_lm/async_llm_task.py +141 -137
  11. llm_utils/lm/async_lm/async_lm.py +48 -42
  12. llm_utils/lm/async_lm/async_lm_base.py +59 -60
  13. llm_utils/lm/async_lm/lm_specific.py +4 -3
  14. llm_utils/lm/base_prompt_builder.py +93 -70
  15. llm_utils/lm/llm.py +126 -108
  16. llm_utils/lm/llm_signature.py +4 -2
  17. llm_utils/lm/lm_base.py +72 -73
  18. llm_utils/lm/mixins.py +102 -62
  19. llm_utils/lm/openai_memoize.py +124 -87
  20. llm_utils/lm/signature.py +105 -92
  21. llm_utils/lm/utils.py +42 -23
  22. llm_utils/scripts/vllm_load_balancer.py +23 -30
  23. llm_utils/scripts/vllm_serve.py +8 -7
  24. llm_utils/vector_cache/__init__.py +9 -3
  25. llm_utils/vector_cache/cli.py +1 -1
  26. llm_utils/vector_cache/core.py +59 -63
  27. llm_utils/vector_cache/types.py +7 -5
  28. llm_utils/vector_cache/utils.py +12 -8
  29. speedy_utils/__imports.py +244 -0
  30. speedy_utils/__init__.py +90 -194
  31. speedy_utils/all.py +125 -227
  32. speedy_utils/common/clock.py +37 -42
  33. speedy_utils/common/function_decorator.py +6 -12
  34. speedy_utils/common/logger.py +43 -52
  35. speedy_utils/common/notebook_utils.py +13 -21
  36. speedy_utils/common/patcher.py +21 -17
  37. speedy_utils/common/report_manager.py +42 -44
  38. speedy_utils/common/utils_cache.py +152 -169
  39. speedy_utils/common/utils_io.py +137 -103
  40. speedy_utils/common/utils_misc.py +15 -21
  41. speedy_utils/common/utils_print.py +22 -28
  42. speedy_utils/multi_worker/process.py +66 -79
  43. speedy_utils/multi_worker/thread.py +78 -155
  44. speedy_utils/scripts/mpython.py +38 -36
  45. speedy_utils/scripts/openapi_client_codegen.py +10 -10
  46. {speedy_utils-1.1.26.dist-info → speedy_utils-1.1.28.dist-info}/METADATA +1 -1
  47. speedy_utils-1.1.28.dist-info/RECORD +57 -0
  48. vision_utils/README.md +202 -0
  49. vision_utils/__init__.py +5 -0
  50. vision_utils/io_utils.py +470 -0
  51. vision_utils/plot.py +345 -0
  52. speedy_utils-1.1.26.dist-info/RECORD +0 -52
  53. {speedy_utils-1.1.26.dist-info → speedy_utils-1.1.28.dist-info}/WHEEL +0 -0
  54. {speedy_utils-1.1.26.dist-info → speedy_utils-1.1.28.dist-info}/entry_points.txt +0 -0
llm_utils/lm/utils.py CHANGED
@@ -1,17 +1,17 @@
  import os
- import signal
- import time
- from typing import Any, List, Optional, cast
-
- from loguru import logger
-

  # Additional imports for VLLM utilities
  import re
+ import signal
  import subprocess
+ import time
+ from typing import Any, List, Optional, cast
+
  import requests
+ from loguru import logger
  from openai import OpenAI

+
  try:
      import psutil

@@ -19,10 +19,12 @@ try:
  except ImportError:
      HAS_PSUTIL = False
      psutil = cast(Any, None)
-     logger.warning("psutil not available. Some VLLM process management features may be limited.")
+     logger.warning(
+         "psutil not available. Some VLLM process management features may be limited."
+     )

  # Global tracking of VLLM processes
- _VLLM_PROCESSES: List[subprocess.Popen] = []
+ _VLLM_PROCESSES: list[subprocess.Popen] = []


  def _extract_port_from_vllm_cmd(vllm_cmd: str) -> int:
@@ -97,7 +99,12 @@ def _start_vllm_server(vllm_cmd: str, timeout: int = 120) -> subprocess.Popen:

      with open(f"/tmp/vllm_{port}.txt", "a") as log_file:
          process = subprocess.Popen(
-             cleaned_cmd.split(), stdout=log_file, stderr=subprocess.STDOUT, text=True, preexec_fn=os.setsid, env=env
+             cleaned_cmd.split(),
+             stdout=log_file,
+             stderr=subprocess.STDOUT,
+             text=True,
+             preexec_fn=os.setsid,
+             env=env,
          )

      _VLLM_PROCESSES.append(process)
@@ -147,7 +154,9 @@ def _kill_vllm_on_port(port: int) -> bool:
              proc = psutil.Process(process.pid)
              cmdline = " ".join(proc.cmdline())
              if f"--port {port}" in cmdline or f"--port={port}" in cmdline:
-                 logger.info(f"Killing tracked VLLM process {process.pid} on port {port}")
+                 logger.info(
+                     f"Killing tracked VLLM process {process.pid} on port {port}"
+                 )
                  os.killpg(os.getpgid(process.pid), signal.SIGTERM)
                  try:
                      process.wait(timeout=5)
@@ -187,8 +196,12 @@ def _kill_vllm_on_port(port: int) -> bool:
          for proc in psutil.process_iter(["pid", "cmdline"]):
              try:
                  cmdline = " ".join(proc.info["cmdline"] or [])
-                 if "vllm" in cmdline.lower() and (f"--port {port}" in cmdline or f"--port={port}" in cmdline):
-                     logger.info(f"Killing untracked VLLM process {proc.info['pid']} on port {port}")
+                 if "vllm" in cmdline.lower() and (
+                     f"--port {port}" in cmdline or f"--port={port}" in cmdline
+                 ):
+                     logger.info(
+                         f"Killing untracked VLLM process {proc.info['pid']} on port {port}"
+                     )
                      proc.terminate()
                      try:
                          proc.wait(timeout=5)
@@ -255,7 +268,9 @@ def _is_server_running(port: int) -> bool:
          return False


- def get_base_client(client=None, cache: bool = True, api_key="abc", vllm_cmd=None, vllm_process=None) -> OpenAI:
+ def get_base_client(
+     client=None, cache: bool = True, api_key="abc", vllm_cmd=None, vllm_process=None
+ ) -> OpenAI:
      """Get OpenAI client from various inputs."""
      from llm_utils import MOpenAI

@@ -264,17 +279,21 @@ def get_base_client(client=None, cache: bool = True, api_key="abc", vllm_cmd=Non
              # Parse environment variables from command to get clean command for port extraction
              _, cleaned_cmd = _parse_env_vars_from_cmd(vllm_cmd)
              port = _extract_port_from_vllm_cmd(cleaned_cmd)
-             return MOpenAI(base_url=f"http://localhost:{port}/v1", api_key=api_key, cache=cache)
-         else:
-             raise ValueError("Either client or vllm_cmd must be provided.")
-     elif isinstance(client, int):
-         return MOpenAI(base_url=f"http://localhost:{client}/v1", api_key=api_key, cache=cache)
-     elif isinstance(client, str):
+             return MOpenAI(
+                 base_url=f"http://localhost:{port}/v1", api_key=api_key, cache=cache
+             )
+         raise ValueError("Either client or vllm_cmd must be provided.")
+     if isinstance(client, int):
+         return MOpenAI(
+             base_url=f"http://localhost:{client}/v1", api_key=api_key, cache=cache
+         )
+     if isinstance(client, str):
          return MOpenAI(base_url=client, api_key=api_key, cache=cache)
-     elif isinstance(client, OpenAI):
+     if isinstance(client, OpenAI):
          return MOpenAI(base_url=client.base_url, api_key=api_key, cache=cache)
-     else:
-         raise ValueError("Invalid client type. Must be OpenAI, port (int), base_url (str), or None.")
+     raise ValueError(
+         "Invalid client type. Must be OpenAI, port (int), base_url (str), or None."
+     )


  def _is_lora_path(path: str) -> bool:
@@ -285,7 +304,7 @@ def _is_lora_path(path: str) -> bool:
      return os.path.isfile(adapter_config_path)


- def _get_port_from_client(client: OpenAI) -> Optional[int]:
+ def _get_port_from_client(client: OpenAI) -> int | None:
      """Extract port from OpenAI client base_url."""
      if hasattr(client, "base_url") and client.base_url:
          base_url = str(client.base_url)
llm_utils/scripts/vllm_load_balancer.py CHANGED
@@ -10,9 +10,11 @@ from datetime import datetime

  import aiohttp
  from loguru import logger
- from speedy_utils import setup_logger
  from tabulate import tabulate

+ from speedy_utils import setup_logger
+
+
  setup_logger(min_interval=5)


@@ -132,10 +134,9 @@ def format_uptime(start_time):

      if hours > 0:
          return f"{hours}h {minutes}m {seconds}s"
-     elif minutes > 0:
+     if minutes > 0:
          return f"{minutes}m {seconds}s"
-     else:
-         return f"{seconds}s"
+     return f"{seconds}s"


  def print_banner():
@@ -247,13 +248,12 @@ async def check_server_health(session, host, port):
                  )
                  await response.release()
                  return True
-             else:
-                 logger.debug(
-                     f"[{LOAD_BALANCER_PORT=}] Health check failed for {url} (Status: {response.status})"
-                 )
-                 await response.release()
-                 return False
-     except asyncio.TimeoutError:
+             logger.debug(
+                 f"[{LOAD_BALANCER_PORT=}] Health check failed for {url} (Status: {response.status})"
+             )
+             await response.release()
+             return False
+     except TimeoutError:
          logger.debug(f"Health check HTTP request timeout for {url}")
          return False
      except aiohttp.ClientConnectorError as e:
@@ -311,11 +311,11 @@ async def scan_and_update_servers():

      if added:
          logger.info(
-             f"Servers added (passed /health check): {sorted(list(added))}"
+             f"Servers added (passed /health check): {sorted(added)}"
          )
      if removed:
          logger.info(
-             f"Servers removed (failed /health check or stopped): {sorted(list(removed))}"
+             f"Servers removed (failed /health check or stopped): {sorted(removed)}"
          )
      for server in removed:
          if server in connection_counts:
@@ -329,7 +329,7 @@
                  f"Removed throttling timestamp for unavailable server {server}"
              )

-     available_servers = sorted(list(current_set))
+     available_servers = sorted(current_set)
      for server in available_servers:
          if server not in connection_counts:
              connection_counts[server] = 0
@@ -375,7 +375,9 @@ async def handle_client(client_reader, client_writer):

      min_connections = float("inf")
      least_used_available_servers = []
-     for server in (
+     for (
+         server
+     ) in (
          available_servers
      ): # Iterate only over servers that passed health check
          count = connection_counts.get(server, 0)
@@ -705,9 +707,9 @@ async def stats_json(request):
              {
                  "host": BACKEND_HOST,
                  "port": port,
-                 "active_connections": connection_counts.get(server, 0)
-                 if is_online
-                 else 0,
+                 "active_connections": (
+                     connection_counts.get(server, 0) if is_online else 0
+                 ),
                  "status": "ONLINE" if is_online else "OFFLINE",
              }
          )
@@ -929,23 +931,14 @@ async def main():
      logger.info("Cancelling background tasks...")
      scan_task.cancel()
      status_task.cancel()
-     try:
+     with contextlib.suppress(asyncio.CancelledError):
          await asyncio.gather(scan_task, status_task, return_exceptions=True)
-     except asyncio.CancelledError:
-         pass
      print(f"{Colors.BRIGHT_GREEN}✅ Shutdown complete. Goodbye!{Colors.RESET}")
      logger.info("Background tasks finished.")


  def run_load_balancer():
-     global \
-         LOAD_BALANCER_PORT, \
-         BACKEND_PORTS, \
-         BACKEND_HOST, \
-         STATUS_PRINT_INTERVAL, \
-         HEALTH_CHECK_TIMEOUT, \
-         THROTTLE_MS, \
-         STATS_PORT
+     global LOAD_BALANCER_PORT, BACKEND_PORTS, BACKEND_HOST, STATUS_PRINT_INTERVAL, HEALTH_CHECK_TIMEOUT, THROTTLE_MS, STATS_PORT
      args = parse_args()
      LOAD_BALANCER_PORT = args.port
      BACKEND_HOST = args.host
@@ -976,4 +969,4 @@ def run_load_balancer():


  if __name__ == "__main__":
-     run_load_balancer()
+     run_load_balancer()
llm_utils/scripts/vllm_serve.py CHANGED
@@ -75,6 +75,7 @@ from loguru import logger
  from llm_utils.lm.openai_memoize import MOpenAI
  from speedy_utils.common.utils_io import load_by_ext

+
  LORA_DIR: str = os.environ.get("LORA_DIR", "/loras")
  LORA_DIR = os.path.abspath(LORA_DIR)
  HF_HOME: str = os.environ.get("HF_HOME", os.path.expanduser("~/.cache/huggingface"))
@@ -93,8 +94,8 @@ def add_lora(
      lora_name_or_path: str,
      host_port: str,
      url: str = "http://HOST:PORT/v1/load_lora_adapter",
-     served_model_name: Optional[str] = None,
-     lora_module: Optional[str] = None,
+     served_model_name: str | None = None,
+     lora_module: str | None = None,
  ) -> dict:
      """Add a LoRA adapter to a running vLLM server."""
      url = url.replace("HOST:PORT", host_port)
@@ -126,7 +127,7 @@
          return {"error": f"Request failed: {str(e)}"}


- def unload_lora(lora_name: str, host_port: str) -> Optional[dict]:
+ def unload_lora(lora_name: str, host_port: str) -> dict | None:
      """Unload a LoRA adapter from a running vLLM server."""
      try:
          url = f"http://{host_port}/v1/unload_lora_adapter"
@@ -144,7 +145,7 @@ def unload_lora(lora_name: str, host_port: str) -> Optional[dict]:
  def serve(args) -> None:
      """Start vLLM containers with dynamic args."""
      print("Starting vLLM containers...,")
-     gpu_groups_arr: List[str] = args.gpu_groups.split(",")
+     gpu_groups_arr: list[str] = args.gpu_groups.split(",")
      vllm_binary: str = get_vllm()
      if args.enable_lora:
          vllm_binary = "VLLM_ALLOW_RUNTIME_LORA_UPDATING=True " + vllm_binary
@@ -232,9 +233,9 @@ def get_vllm() -> str:
      vllm_binary = subprocess.check_output("which vllm", shell=True, text=True).strip()
      vllm_binary = os.getenv("VLLM_BINARY", vllm_binary)
      logger.info(f"vLLM binary: {vllm_binary}")
-     assert os.path.exists(vllm_binary), (
-         f"vLLM binary not found at {vllm_binary}, please set VLLM_BINARY env variable"
-     )
+     assert os.path.exists(
+         vllm_binary
+     ), f"vLLM binary not found at {vllm_binary}, please set VLLM_BINARY env variable"
      return vllm_binary


llm_utils/vector_cache/__init__.py CHANGED
@@ -11,15 +11,21 @@ Example:
      # Using local model
      cache = VectorCache("Qwen/Qwen3-Embedding-0.6B")
      embeddings = cache.embeds(["Hello world", "How are you?"])
-
+
      # Using OpenAI API
      cache = VectorCache("https://api.openai.com/v1")
      embeddings = cache.embeds(["Hello world", "How are you?"])
  """

  from .core import VectorCache
- from .utils import get_default_cache_path, validate_model_name, estimate_cache_size
+ from .utils import estimate_cache_size, get_default_cache_path, validate_model_name
+

  __version__ = "0.1.0"
  __author__ = "AnhVTH <anhvth.226@gmail.com>"
- __all__ = ["VectorCache", "get_default_cache_path", "validate_model_name", "estimate_cache_size"]
+ __all__ = [
+     "VectorCache",
+     "get_default_cache_path",
+     "validate_model_name",
+     "estimate_cache_size",
+ ]
llm_utils/vector_cache/cli.py CHANGED
@@ -106,7 +106,7 @@ def handle_embed(args):
          if not file_path.exists():
              raise FileNotFoundError(f"File not found: {args.file}")

-         with open(file_path, "r", encoding="utf-8") as f:
+         with open(file_path, encoding="utf-8") as f:
              texts.extend([line.strip() for line in f if line.strip()])

          if not texts:
llm_utils/vector_cache/core.py CHANGED
@@ -61,18 +61,18 @@ class VectorCache:
      def __init__(
          self,
          url_or_model: str,
-         backend: Optional[Literal["vllm", "transformers", "openai"]] = None,
-         embed_size: Optional[int] = None,
-         db_path: Optional[str] = None,
+         backend: Literal["vllm", "transformers", "openai"] | None = None,
+         embed_size: int | None = None,
+         db_path: str | None = None,
          # OpenAI API parameters
-         api_key: Optional[str] = "abc",
-         model_name: Optional[str] = None,
+         api_key: str | None = "abc",
+         model_name: str | None = None,
          # vLLM parameters
          vllm_gpu_memory_utilization: float = 0.5,
          vllm_tensor_parallel_size: int = 1,
          vllm_dtype: str = "auto",
          vllm_trust_remote_code: bool = False,
-         vllm_max_model_len: Optional[int] = None,
+         vllm_max_model_len: int | None = None,
          # Transformers parameters
          transformers_device: str = "auto",
          transformers_batch_size: int = 32,
@@ -149,7 +149,6 @@ class VectorCache:
              if self.verbose:
                  print(f"Model auto-detection failed: {e}, using default model")
              # Fallback to default if auto-detection fails
-             pass

          # Set default db_path if not provided
          if db_path is None:
@@ -185,7 +184,7 @@ class VectorCache:
              print(f"✓ {self.backend.upper()} model/client loaded successfully")

      def _determine_backend(
-         self, backend: Optional[Literal["vllm", "transformers", "openai"]]
+         self, backend: Literal["vllm", "transformers", "openai"] | None
      ) -> str:
          """Determine the appropriate backend based on url_or_model and user preference."""
          if backend is not None:
@@ -202,7 +201,7 @@ class VectorCache:
          # Default to vllm for local models
          return "vllm"

-     def _try_infer_model_name(self, model_name: Optional[str]) -> Optional[str]:
+     def _try_infer_model_name(self, model_name: str | None) -> str | None:
          """Infer model name for OpenAI backend if not explicitly provided."""
          if model_name:
              return model_name
@@ -243,17 +242,21 @@ class VectorCache:
          ) # Checkpoint WAL every 1000 pages

      def _ensure_schema(self) -> None:
-         self.conn.execute("""
+         self.conn.execute(
+             """
              CREATE TABLE IF NOT EXISTS cache (
                  hash TEXT PRIMARY KEY,
                  text TEXT,
                  embedding BLOB
              )
-         """)
+             """
+         )
          # Add index for faster lookups if it doesn't exist
-         self.conn.execute("""
+         self.conn.execute(
+             """
              CREATE INDEX IF NOT EXISTS idx_cache_hash ON cache(hash)
-         """)
+             """
+         )
          self.conn.commit()

      def _load_openai_client(self) -> None:
@@ -275,7 +278,7 @@ class VectorCache:
              tensor_parallel_size = cast(int, self.config["vllm_tensor_parallel_size"])
              dtype = cast(str, self.config["vllm_dtype"])
              trust_remote_code = cast(bool, self.config["vllm_trust_remote_code"])
-             max_model_len = cast(Optional[int], self.config["vllm_max_model_len"])
+             max_model_len = cast(int | None, self.config["vllm_max_model_len"])

              vllm_kwargs = {
                  "model": self.url_or_model,
@@ -312,8 +315,7 @@ class VectorCache:
                      f"4. Ensure no other processes are using GPU memory during initialization\n"
                      f"Original error: {e}"
                  ) from e
-             else:
-                 raise
+             raise
          elif self.backend == "transformers":
              import torch # type: ignore[import-not-found] # noqa: F401
              from transformers import ( # type: ignore[import-not-found]
@@ -345,29 +347,28 @@ class VectorCache:
      def _get_embeddings(self, texts: list[str]) -> list[list[float]]:
          """Get embeddings using the configured backend."""
          assert isinstance(texts, list), "texts must be a list"
-         assert all(isinstance(t, str) for t in texts), (
-             "all elements in texts must be strings"
-         )
+         assert all(
+             isinstance(t, str) for t in texts
+         ), "all elements in texts must be strings"
          if self.backend == "openai":
              return self._get_openai_embeddings(texts)
-         elif self.backend == "vllm":
+         if self.backend == "vllm":
              return self._get_vllm_embeddings(texts)
-         elif self.backend == "transformers":
+         if self.backend == "transformers":
              return self._get_transformers_embeddings(texts)
-         else:
-             raise ValueError(f"Unsupported backend: {self.backend}")
+         raise ValueError(f"Unsupported backend: {self.backend}")

      def _get_openai_embeddings(self, texts: list[str]) -> list[list[float]]:
          """Get embeddings using OpenAI API."""
          assert isinstance(texts, list), "texts must be a list"
-         assert all(isinstance(t, str) for t in texts), (
-             "all elements in texts must be strings"
-         )
+         assert all(
+             isinstance(t, str) for t in texts
+         ), "all elements in texts must be strings"
          # Assert valid model_name for OpenAI backend
          model_name = self.config["model_name"]
-         assert model_name is not None and model_name.strip(), (
-             f"Invalid model_name for OpenAI backend: {model_name}. Model name must be provided and non-empty."
-         )
+         assert (
+             model_name is not None and model_name.strip()
+         ), f"Invalid model_name for OpenAI backend: {model_name}. Model name must be provided and non-empty."

          if self._client is None:
              if self.verbose:
@@ -385,9 +386,9 @@ class VectorCache:
      def _get_vllm_embeddings(self, texts: list[str]) -> list[list[float]]:
          """Get embeddings using vLLM."""
          assert isinstance(texts, list), "texts must be a list"
-         assert all(isinstance(t, str) for t in texts), (
-             "all elements in texts must be strings"
-         )
+         assert all(
+             isinstance(t, str) for t in texts
+         ), "all elements in texts must be strings"
          if self._model is None:
              if self.verbose:
                  print("🔧 Loading vLLM model...")
@@ -402,9 +403,9 @@ class VectorCache:
      def _get_transformers_embeddings(self, texts: list[str]) -> list[list[float]]:
          """Get embeddings using transformers directly."""
          assert isinstance(texts, list), "texts must be a list"
-         assert all(isinstance(t, str) for t in texts), (
-             "all elements in texts must be strings"
-         )
+         assert all(
+             isinstance(t, str) for t in texts
+         ), "all elements in texts must be strings"
          if self._model is None:
              if self.verbose:
                  print("🔧 Loading Transformers model...")
@@ -464,13 +465,12 @@ class VectorCache:
          left_padding = attention_mask[:, -1].sum() == attention_mask.shape[0]
          if left_padding:
              return last_hidden_states[:, -1]
-         else:
-             sequence_lengths = attention_mask.sum(dim=1) - 1
-             batch_size = last_hidden_states.shape[0]
-             return last_hidden_states[
-                 torch.arange(batch_size, device=last_hidden_states.device),
-                 sequence_lengths,
-             ]
+         sequence_lengths = attention_mask.sum(dim=1) - 1
+         batch_size = last_hidden_states.shape[0]
+         return last_hidden_states[
+             torch.arange(batch_size, device=last_hidden_states.device),
+             sequence_lengths,
+         ]

      def _hash_text(self, text: str) -> str:
          return hashlib.sha1(text.encode("utf-8")).hexdigest()
@@ -486,8 +486,7 @@ class VectorCache:
              try:
                  if params is None:
                      return self.conn.execute(query)
-                 else:
-                     return self.conn.execute(query, params)
+                 return self.conn.execute(query, params)

              except sqlite3.OperationalError as e:
                  last_exception = e
@@ -502,9 +501,8 @@ class VectorCache:

                      time.sleep(delay)
                      continue
-                 else:
-                     # Re-raise if not a lock error or max retries exceeded
-                     raise
+                 # Re-raise if not a lock error or max retries exceeded
+                 raise
              except Exception:
                  # Re-raise any other exceptions
                  raise
@@ -524,9 +522,9 @@ class VectorCache:
          computing missing embeddings.
          """
          assert isinstance(texts, list), "texts must be a list"
-         assert all(isinstance(t, str) for t in texts), (
-             "all elements in texts must be strings"
-         )
+         assert all(
+             isinstance(t, str) for t in texts
+         ), "all elements in texts must be strings"
          if not texts:
              return np.empty((0, 0), dtype=np.float32)
          t = time()
@@ -554,11 +552,11 @@ class VectorCache:
          # Determine which texts are missing
          if cache:
              missing_items: list[tuple[str, str]] = [
-                 (t, h) for t, h in zip(texts, hashes) if h not in hit_map
+                 (t, h) for t, h in zip(texts, hashes, strict=False) if h not in hit_map
              ]
          else:
              missing_items: list[tuple[str, str]] = [
-                 (t, h) for t, h in zip(texts, hashes)
+                 (t, h) for t, h in zip(texts, hashes, strict=False)
              ]

          if missing_items:
@@ -608,7 +606,7 @@ class VectorCache:

              # Prepare batch data for immediate insert
              batch_data: list[tuple[str, str, bytes]] = []
-             for (text, h), vec in zip(batch_items, batch_embeds):
+             for (text, h), vec in zip(batch_items, batch_embeds, strict=False):
                  arr = np.asarray(vec, dtype=np.float32)
                  batch_data.append((h, text, arr.tobytes()))
                  hit_map[h] = arr
@@ -640,9 +638,9 @@ class VectorCache:

      def __call__(self, texts: list[str], cache: bool = True) -> np.ndarray:
          assert isinstance(texts, list), "texts must be a list"
-         assert all(isinstance(t, str) for t in texts), (
-             "all elements in texts must be strings"
-         )
+         assert all(
+             isinstance(t, str) for t in texts
+         ), "all elements in texts must be strings"
          return self.embeds(texts, cache)

      def _bulk_insert(self, data: list[tuple[str, str, bytes]]) -> None:
@@ -662,7 +660,7 @@ class VectorCache:

          for attempt in range(max_retries + 1):
              try:
-                 cursor = self.conn.executemany(
+                 self.conn.executemany(
                      "INSERT OR IGNORE INTO cache (hash, text, embedding) VALUES (?, ?, ?)",
                      data,
                  )
@@ -688,9 +686,8 @@ class VectorCache:

                      time.sleep(delay)
                      continue
-                 else:
-                     # Re-raise if not a lock error or max retries exceeded
-                     raise
+                 # Re-raise if not a lock error or max retries exceeded
+                 raise
              except Exception:
                  # Re-raise any other exceptions
                  raise
@@ -723,12 +720,11 @@ class VectorCache:

                      time.sleep(delay)
                      continue
-                 else:
-                     raise
+                 raise
              except Exception:
                  raise

-     def get_config(self) -> Dict[str, Any]:
+     def get_config(self) -> dict[str, Any]:
          """Get current configuration."""
          return {
              "url_or_model": self.url_or_model,
llm_utils/vector_cache/types.py CHANGED
@@ -1,15 +1,17 @@
  """Type definitions for the embed_cache package."""

- from typing import List, Dict, Any, Union, Optional, Tuple
+ from typing import Any, Dict, List, Optional, Tuple, Union
+
  import numpy as np
  from numpy.typing import NDArray

+
  # Type aliases
- TextList = List[str]
+ TextList = list[str]
  EmbeddingArray = NDArray[np.float32]
- EmbeddingList = List[List[float]]
- CacheStats = Dict[str, int]
+ EmbeddingList = list[list[float]]
+ CacheStats = dict[str, int]
  ModelIdentifier = str # Either URL or model name/path

  # For backwards compatibility
- Embeddings = Union[EmbeddingArray, EmbeddingList]
+ Embeddings = Union[EmbeddingArray, EmbeddingList]