speedy-utils 1.1.12__tar.gz → 1.1.14__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44)
  1. {speedy_utils-1.1.12 → speedy_utils-1.1.14}/PKG-INFO +2 -1
  2. {speedy_utils-1.1.12 → speedy_utils-1.1.14}/pyproject.toml +2 -1
  3. {speedy_utils-1.1.12 → speedy_utils-1.1.14}/src/llm_utils/__init__.py +4 -1
  4. {speedy_utils-1.1.12 → speedy_utils-1.1.14}/src/llm_utils/lm/async_lm/async_lm.py +2 -1
  5. {speedy_utils-1.1.12 → speedy_utils-1.1.14}/src/llm_utils/lm/async_lm/async_lm_base.py +6 -6
  6. speedy_utils-1.1.14/src/llm_utils/vector_cache/__init__.py +25 -0
  7. speedy_utils-1.1.14/src/llm_utils/vector_cache/cli.py +200 -0
  8. speedy_utils-1.1.14/src/llm_utils/vector_cache/core.py +557 -0
  9. speedy_utils-1.1.14/src/llm_utils/vector_cache/types.py +15 -0
  10. speedy_utils-1.1.14/src/llm_utils/vector_cache/utils.py +42 -0
  11. {speedy_utils-1.1.12 → speedy_utils-1.1.14}/src/speedy_utils/__init__.py +1 -1
  12. {speedy_utils-1.1.12 → speedy_utils-1.1.14}/src/speedy_utils/common/logger.py +2 -2
  13. {speedy_utils-1.1.12 → speedy_utils-1.1.14}/src/speedy_utils/common/utils_io.py +2 -2
  14. {speedy_utils-1.1.12 → speedy_utils-1.1.14}/src/speedy_utils/common/utils_print.py +5 -5
  15. {speedy_utils-1.1.12 → speedy_utils-1.1.14}/src/speedy_utils/multi_worker/process.py +4 -4
  16. {speedy_utils-1.1.12 → speedy_utils-1.1.14}/src/speedy_utils/multi_worker/thread.py +4 -4
  17. {speedy_utils-1.1.12 → speedy_utils-1.1.14}/README.md +0 -0
  18. {speedy_utils-1.1.12 → speedy_utils-1.1.14}/src/llm_utils/chat_format/__init__.py +0 -0
  19. {speedy_utils-1.1.12 → speedy_utils-1.1.14}/src/llm_utils/chat_format/display.py +0 -0
  20. {speedy_utils-1.1.12 → speedy_utils-1.1.14}/src/llm_utils/chat_format/transform.py +0 -0
  21. {speedy_utils-1.1.12 → speedy_utils-1.1.14}/src/llm_utils/chat_format/utils.py +0 -0
  22. {speedy_utils-1.1.12 → speedy_utils-1.1.14}/src/llm_utils/group_messages.py +0 -0
  23. {speedy_utils-1.1.12 → speedy_utils-1.1.14}/src/llm_utils/lm/__init__.py +0 -0
  24. {speedy_utils-1.1.12 → speedy_utils-1.1.14}/src/llm_utils/lm/async_lm/__init__.py +0 -0
  25. {speedy_utils-1.1.12 → speedy_utils-1.1.14}/src/llm_utils/lm/async_lm/_utils.py +0 -0
  26. {speedy_utils-1.1.12 → speedy_utils-1.1.14}/src/llm_utils/lm/async_lm/async_llm_task.py +0 -0
  27. {speedy_utils-1.1.12 → speedy_utils-1.1.14}/src/llm_utils/lm/async_lm/lm_specific.py +0 -0
  28. {speedy_utils-1.1.12 → speedy_utils-1.1.14}/src/llm_utils/lm/openai_memoize.py +0 -0
  29. {speedy_utils-1.1.12 → speedy_utils-1.1.14}/src/llm_utils/lm/utils.py +0 -0
  30. {speedy_utils-1.1.12 → speedy_utils-1.1.14}/src/llm_utils/scripts/README.md +0 -0
  31. {speedy_utils-1.1.12 → speedy_utils-1.1.14}/src/llm_utils/scripts/vllm_load_balancer.py +0 -0
  32. {speedy_utils-1.1.12 → speedy_utils-1.1.14}/src/llm_utils/scripts/vllm_serve.py +0 -0
  33. {speedy_utils-1.1.12 → speedy_utils-1.1.14}/src/speedy_utils/all.py +0 -0
  34. {speedy_utils-1.1.12 → speedy_utils-1.1.14}/src/speedy_utils/common/__init__.py +0 -0
  35. {speedy_utils-1.1.12 → speedy_utils-1.1.14}/src/speedy_utils/common/clock.py +0 -0
  36. {speedy_utils-1.1.12 → speedy_utils-1.1.14}/src/speedy_utils/common/function_decorator.py +0 -0
  37. {speedy_utils-1.1.12 → speedy_utils-1.1.14}/src/speedy_utils/common/notebook_utils.py +0 -0
  38. {speedy_utils-1.1.12 → speedy_utils-1.1.14}/src/speedy_utils/common/report_manager.py +0 -0
  39. {speedy_utils-1.1.12 → speedy_utils-1.1.14}/src/speedy_utils/common/utils_cache.py +0 -0
  40. {speedy_utils-1.1.12 → speedy_utils-1.1.14}/src/speedy_utils/common/utils_misc.py +0 -0
  41. {speedy_utils-1.1.12 → speedy_utils-1.1.14}/src/speedy_utils/multi_worker/__init__.py +0 -0
  42. {speedy_utils-1.1.12 → speedy_utils-1.1.14}/src/speedy_utils/scripts/__init__.py +0 -0
  43. {speedy_utils-1.1.12 → speedy_utils-1.1.14}/src/speedy_utils/scripts/mpython.py +0 -0
  44. {speedy_utils-1.1.12 → speedy_utils-1.1.14}/src/speedy_utils/scripts/openapi_client_codegen.py +0 -0

{speedy_utils-1.1.12 → speedy_utils-1.1.14}/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.3
  Name: speedy-utils
- Version: 1.1.12
+ Version: 1.1.14
  Summary: Fast and easy-to-use package for data science
  Author: AnhVTH
  Author-email: anhvth.226@gmail.com
@@ -25,6 +25,7 @@ Requires-Dist: jupyterlab
  Requires-Dist: loguru
  Requires-Dist: matplotlib
  Requires-Dist: numpy
+ Requires-Dist: openai (>=1.106.0,<2.0.0)
  Requires-Dist: packaging (>=23.2,<25)
  Requires-Dist: pandas
  Requires-Dist: pydantic

{speedy_utils-1.1.12 → speedy_utils-1.1.14}/pyproject.toml
@@ -1,6 +1,6 @@
  [tool.poetry]
  name = "speedy-utils"
- version = "1.1.12"
+ version = "1.1.14"
  description = "Fast and easy-to-use package for data science"
  authors = ["AnhVTH <anhvth.226@gmail.com>"]
  readme = "README.md"
@@ -58,6 +58,7 @@ json-repair = ">=0.25.0,<0.31.0"
  fastprogress = "*"
  freezegun = "^1.5.1"
  packaging = ">=23.2,<25"
+ openai = "^1.106.0"

  [tool.poetry.scripts]
  mpython = "speedy_utils.scripts.mpython:main"

{speedy_utils-1.1.12 → speedy_utils-1.1.14}/src/llm_utils/__init__.py
@@ -1,4 +1,6 @@
  from llm_utils.lm.openai_memoize import MOpenAI
+ from llm_utils.vector_cache import VectorCache
+
  from .chat_format import (
      build_chatml_input,
      display_chat_messages_as_html,
@@ -24,5 +26,6 @@ __all__ = [
      "display_chat_messages_as_html",
      "AsyncLM",
      "AsyncLLMTask",
-     "MOpenAI"
+     "MOpenAI",
+     "VectorCache",
  ]

{speedy_utils-1.1.12 → speedy_utils-1.1.14}/src/llm_utils/lm/async_lm/async_lm.py
@@ -5,6 +5,7 @@ from typing import (
      Literal,
      Optional,
      Type,
+     Union,
      cast,
  )

@@ -49,7 +50,7 @@ class AsyncLM(AsyncLMBase):
          temperature: float = 0.0,
          max_tokens: int = 2_000,
          host: str = "localhost",
-         port: Optional[int | str] = None,
+         port: Optional[Union[int, str]] = None,
          base_url: Optional[str] = None,
          api_key: Optional[str] = None,
          cache: bool = True,

{speedy_utils-1.1.12 → speedy_utils-1.1.14}/src/llm_utils/lm/async_lm/async_lm_base.py
@@ -41,7 +41,7 @@ class AsyncLMBase:
          self,
          *,
          host: str = "localhost",
-         port: Optional[int | str] = None,
+         port: Optional[Union[int, str]] = None,
          base_url: Optional[str] = None,
          api_key: Optional[str] = None,
          cache: bool = True,
@@ -81,8 +81,8 @@ class AsyncLMBase:
      async def __call__(  # type: ignore
          self,
          *,
-         prompt: str | None = ...,
-         messages: RawMsgs | None = ...,
+         prompt: Optional[str] = ...,
+         messages: Optional[RawMsgs] = ...,
          response_format: type[str] = str,
          return_openai_response: bool = ...,
          **kwargs: Any,
@@ -92,8 +92,8 @@ class AsyncLMBase:
      async def __call__(
          self,
          *,
-         prompt: str | None = ...,
-         messages: RawMsgs | None = ...,
+         prompt: Optional[str] = ...,
+         messages: Optional[RawMsgs] = ...,
          response_format: Type[TModel],
          return_openai_response: bool = ...,
          **kwargs: Any,
@@ -137,7 +137,7 @@ class AsyncLMBase:
      @staticmethod
      def _parse_output(
          raw_response: Any, response_format: Union[type[str], Type[BaseModel]]
-     ) -> str | BaseModel:
+     ) -> Union[str, BaseModel]:
          if hasattr(raw_response, "model_dump"):
              raw_response = raw_response.model_dump()

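
Not part of the diff: the recurring change in this release from X | Y annotations to Union[X, Y] / Optional[X] is presumably for compatibility with Python versions older than 3.10, where the PEP 604 "X | Y" union syntax is not available at runtime. A minimal illustration of the equivalent spellings:

    from typing import Optional, Union

    # PEP 604 style, needs Python 3.10+ when the annotation is evaluated:
    #   port: int | str | None = None

    # Equivalent spelling used throughout this diff, valid on older versions too:
    port: Optional[Union[int, str]] = None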

speedy_utils-1.1.14/src/llm_utils/vector_cache/__init__.py
@@ -0,0 +1,25 @@
+ """
+ Efficient embedding caching system using vLLM for offline embeddings.
+
+ This package provides a fast, SQLite-backed caching layer for text embeddings,
+ supporting both OpenAI API and local models via vLLM.
+
+ Classes:
+     VectorCache: Main class for embedding computation and caching
+
+ Example:
+     # Using local model
+     cache = VectorCache("Qwen/Qwen3-Embedding-0.6B")
+     embeddings = cache.embeds(["Hello world", "How are you?"])
+
+     # Using OpenAI API
+     cache = VectorCache("https://api.openai.com/v1")
+     embeddings = cache.embeds(["Hello world", "How are you?"])
+ """
+
+ from .core import VectorCache
+ from .utils import get_default_cache_path, validate_model_name, estimate_cache_size
+
+ __version__ = "0.1.0"
+ __author__ = "AnhVTH <anhvth.226@gmail.com>"
+ __all__ = ["VectorCache", "get_default_cache_path", "validate_model_name", "estimate_cache_size"]
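
Not part of the diff: a minimal usage sketch based on the docstring above and on the new VectorCache export from llm_utils; the model name and texts are the illustrative values from the docstring, not a recommendation.

    from llm_utils import VectorCache

    # Local embedding model (served through vLLM by default)
    cache = VectorCache("Qwen/Qwen3-Embedding-0.6B")
    vectors = cache.embeds(["Hello world", "How are you?"])  # numpy array, one row per text

    # A second call for the same text is served from the SQLite cache
    vectors_again = cache.embeds(["Hello world"])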

speedy_utils-1.1.14/src/llm_utils/vector_cache/cli.py
@@ -0,0 +1,200 @@
+ #!/usr/bin/env python3
+ """Command-line interface for embed_cache package."""
+
+ import argparse
+ import json
+ import sys
+ from pathlib import Path
+ from typing import List
+
+ from llm_utils.vector_cache import VectorCache, estimate_cache_size, validate_model_name
+
+
+ def main():
+     """Main CLI entry point."""
+     parser = argparse.ArgumentParser(
+         description="Command-line interface for embed_cache package"
+     )
+
+     subparsers = parser.add_subparsers(dest="command", help="Available commands")
+
+     # Embed command
+     embed_parser = subparsers.add_parser("embed", help="Generate embeddings for texts")
+     embed_parser.add_argument("model", help="Model name or API URL")
+     embed_parser.add_argument("--texts", nargs="+", help="Texts to embed")
+     embed_parser.add_argument("--file", help="File containing texts (one per line)")
+     embed_parser.add_argument("--output", help="Output file for embeddings (JSON)")
+     embed_parser.add_argument(
+         "--cache-db", default="embed_cache.sqlite", help="Cache database path"
+     )
+     embed_parser.add_argument(
+         "--backend",
+         choices=["vllm", "transformers", "openai", "auto"],
+         default="auto",
+         help="Backend to use",
+     )
+     embed_parser.add_argument(
+         "--gpu-memory",
+         type=float,
+         default=0.5,
+         help="GPU memory utilization for vLLM (0.0-1.0)",
+     )
+     embed_parser.add_argument(
+         "--batch-size", type=int, default=32, help="Batch size for transformers"
+     )
+     embed_parser.add_argument(
+         "--verbose", action="store_true", help="Enable verbose output"
+     )
+
+     # Cache stats command
+     stats_parser = subparsers.add_parser("stats", help="Show cache statistics")
+     stats_parser.add_argument(
+         "--cache-db", default="embed_cache.sqlite", help="Cache database path"
+     )
+
+     # Clear cache command
+     clear_parser = subparsers.add_parser("clear", help="Clear cache")
+     clear_parser.add_argument(
+         "--cache-db", default="embed_cache.sqlite", help="Cache database path"
+     )
+     clear_parser.add_argument(
+         "--confirm", action="store_true", help="Skip confirmation prompt"
+     )
+
+     # Validate model command
+     validate_parser = subparsers.add_parser("validate", help="Validate model name")
+     validate_parser.add_argument("model", help="Model name to validate")
+
+     # Estimate command
+     estimate_parser = subparsers.add_parser("estimate", help="Estimate cache size")
+     estimate_parser.add_argument("num_texts", type=int, help="Number of texts")
+     estimate_parser.add_argument(
+         "--embed-dim", type=int, default=1024, help="Embedding dimension"
+     )
+
+     args = parser.parse_args()
+
+     if not args.command:
+         parser.print_help()
+         return
+
+     try:
+         if args.command == "embed":
+             handle_embed(args)
+         elif args.command == "stats":
+             handle_stats(args)
+         elif args.command == "clear":
+             handle_clear(args)
+         elif args.command == "validate":
+             handle_validate(args)
+         elif args.command == "estimate":
+             handle_estimate(args)
+     except Exception as e:
+         print(f"Error: {e}", file=sys.stderr)
+         sys.exit(1)
+
+
+ def handle_embed(args):
+     """Handle embed command."""
+     # Get texts
+     texts = []
+     if args.texts:
+         texts.extend(args.texts)
+
+     if args.file:
+         file_path = Path(args.file)
+         if not file_path.exists():
+             raise FileNotFoundError(f"File not found: {args.file}")
+
+         with open(file_path, "r", encoding="utf-8") as f:
+             texts.extend([line.strip() for line in f if line.strip()])
+
+     if not texts:
+         raise ValueError("No texts provided. Use --texts or --file")
+
+     print(f"Embedding {len(texts)} texts using model: {args.model}")
+
+     # Initialize cache and get embeddings
+     cache = VectorCache(args.model, db_path=args.cache_db)
+     embeddings = cache.embeds(texts)
+
+     print(f"Generated embeddings with shape: {embeddings.shape}")
+
+     # Output results
+     if args.output:
+         output_data = {
+             "texts": texts,
+             "embeddings": embeddings.tolist(),
+             "shape": list(embeddings.shape),
+             "model": args.model,
+         }
+
+         with open(args.output, "w", encoding="utf-8") as f:
+             json.dump(output_data, f, indent=2)
+
+         print(f"Results saved to: {args.output}")
+     else:
+         print(f"Embeddings shape: {embeddings.shape}")
+         print(f"Sample embedding (first 5 dims): {embeddings[0][:5].tolist()}")
+
+
+ def handle_stats(args):
+     """Handle stats command."""
+     cache_path = Path(args.cache_db)
+     if not cache_path.exists():
+         print(f"Cache database not found: {args.cache_db}")
+         return
+
+     cache = VectorCache("dummy", db_path=args.cache_db)
+     stats = cache.get_cache_stats()
+
+     print("Cache Statistics:")
+     print(f"  Database: {args.cache_db}")
+     print(f"  Total cached embeddings: {stats['total_cached']}")
+     print(f"  Database size: {cache_path.stat().st_size / (1024 * 1024):.2f} MB")
+
+
+ def handle_clear(args):
+     """Handle clear command."""
+     cache_path = Path(args.cache_db)
+     if not cache_path.exists():
+         print(f"Cache database not found: {args.cache_db}")
+         return
+
+     if not args.confirm:
+         response = input(
+             f"Are you sure you want to clear cache at {args.cache_db}? [y/N]: "
+         )
+         if response.lower() != "y":
+             print("Cancelled.")
+             return
+
+     cache = VectorCache("dummy", db_path=args.cache_db)
+     stats_before = cache.get_cache_stats()
+     cache.clear_cache()
+
+     print(
+         f"Cleared {stats_before['total_cached']} cached embeddings from {args.cache_db}"
+     )
+
+
+ def handle_validate(args):
+     """Handle validate command."""
+     is_valid = validate_model_name(args.model)
+     if is_valid:
+         print(f"✓ Valid model: {args.model}")
+     else:
+         print(f"✗ Invalid model: {args.model}")
+         sys.exit(1)
+
+
+ def handle_estimate(args):
+     """Handle estimate command."""
+     size_estimate = estimate_cache_size(args.num_texts, args.embed_dim)
+     print(
+         f"Estimated cache size for {args.num_texts} texts ({args.embed_dim}D embeddings): {size_estimate}"
+     )
+
+
+ if __name__ == "__main__":
+     main()

speedy_utils-1.1.14/src/llm_utils/vector_cache/core.py
@@ -0,0 +1,557 @@
+ from __future__ import annotations
+
+ import hashlib
+ import os
+ import sqlite3
+ from pathlib import Path
+ from time import time
+ from typing import Any, Dict, Literal, Optional, cast
+
+ import numpy as np
+
+
+ class VectorCache:
+     """
+     A caching layer for text embeddings with support for multiple backends.
+
+     Examples:
+         # OpenAI API
+         from llm_utils import VectorCache
+         cache = VectorCache("https://api.openai.com/v1", api_key="your-key")
+         embeddings = cache.embeds(["Hello world", "How are you?"])
+
+         # Custom OpenAI-compatible server (auto-detects model)
+         cache = VectorCache("http://localhost:8000/v1", api_key="abc")
+
+         # Transformers (Sentence Transformers)
+         cache = VectorCache("sentence-transformers/all-MiniLM-L6-v2")
+
+         # vLLM (local model)
+         cache = VectorCache("/path/to/model")
+
+         # Explicit backend specification
+         cache = VectorCache("model-name", backend="transformers")
+
+         # Lazy loading (default: True) - load model only when needed
+         cache = VectorCache("model-name", lazy=True)
+
+         # Eager loading - load model immediately
+         cache = VectorCache("model-name", lazy=False)
+     """
+     def __init__(
+         self,
+         url_or_model: str,
+         backend: Optional[Literal["vllm", "transformers", "openai"]] = None,
+         embed_size: Optional[int] = None,
+         db_path: Optional[str] = None,
+         # OpenAI API parameters
+         api_key: Optional[str] = "abc",
+         model_name: Optional[str] = None,
+         # vLLM parameters
+         vllm_gpu_memory_utilization: float = 0.5,
+         vllm_tensor_parallel_size: int = 1,
+         vllm_dtype: str = "auto",
+         vllm_trust_remote_code: bool = False,
+         vllm_max_model_len: Optional[int] = None,
+         # Transformers parameters
+         transformers_device: str = "auto",
+         transformers_batch_size: int = 32,
+         transformers_normalize_embeddings: bool = True,
+         transformers_trust_remote_code: bool = False,
+         # SQLite parameters
+         sqlite_chunk_size: int = 999,
+         sqlite_cache_size: int = 10000,
+         sqlite_mmap_size: int = 268435456,
+         # Other parameters
+         verbose: bool = True,
+         lazy: bool = True,
+     ) -> None:
+         self.url_or_model = url_or_model
+         self.embed_size = embed_size
+         self.verbose = verbose
+         self.lazy = lazy
+
+         self.backend = self._determine_backend(backend)
+         if self.verbose and backend is None:
+             print(f"Auto-detected backend: {self.backend}")
+
+         # Store all configuration parameters
+         self.config = {
+             # OpenAI
+             "api_key": api_key or os.getenv("OPENAI_API_KEY"),
+             "model_name": self._try_infer_model_name(model_name),
+             # vLLM
+             "vllm_gpu_memory_utilization": vllm_gpu_memory_utilization,
+             "vllm_tensor_parallel_size": vllm_tensor_parallel_size,
+             "vllm_dtype": vllm_dtype,
+             "vllm_trust_remote_code": vllm_trust_remote_code,
+             "vllm_max_model_len": vllm_max_model_len,
+             # Transformers
+             "transformers_device": transformers_device,
+             "transformers_batch_size": transformers_batch_size,
+             "transformers_normalize_embeddings": transformers_normalize_embeddings,
+             "transformers_trust_remote_code": transformers_trust_remote_code,
+             # SQLite
+             "sqlite_chunk_size": sqlite_chunk_size,
+             "sqlite_cache_size": sqlite_cache_size,
+             "sqlite_mmap_size": sqlite_mmap_size,
+         }
+
+         # Auto-detect model_name for OpenAI if using custom URL and default model
+         if (self.backend == "openai" and
+             model_name == "text-embedding-3-small" and
+             self.url_or_model != "https://api.openai.com/v1"):
+             if self.verbose:
+                 print(f"Attempting to auto-detect model from {self.url_or_model}...")
+             try:
+                 import openai
+                 client = openai.OpenAI(
+                     base_url=self.url_or_model,
+                     api_key=self.config["api_key"]
+                 )
+                 models = client.models.list()
+                 if models.data:
+                     detected_model = models.data[0].id
+                     self.config["model_name"] = detected_model
+                     model_name = detected_model  # Update for db_path computation
+                     if self.verbose:
+                         print(f"Auto-detected model: {detected_model}")
+                 else:
+                     if self.verbose:
+                         print("No models found, using default model")
+             except Exception as e:
+                 if self.verbose:
+                     print(f"Model auto-detection failed: {e}, using default model")
+                 # Fallback to default if auto-detection fails
+                 pass
+
+         # Set default db_path if not provided
+         if db_path is None:
+             if self.backend == "openai":
+                 model_id = self.config["model_name"] or "openai-default"
+             else:
+                 model_id = self.url_or_model
+             safe_name = hashlib.sha1(model_id.encode("utf-8")).hexdigest()[:16]
+             self.db_path = Path.home() / ".cache" / "embed" / f"{self.backend}_{safe_name}.sqlite"
+         else:
+             self.db_path = Path(db_path)
+
+         # Ensure the directory exists
+         self.db_path.parent.mkdir(parents=True, exist_ok=True)
+
+         self.conn = sqlite3.connect(self.db_path)
+         self._optimize_connection()
+         self._ensure_schema()
+         self._model = None  # Lazy loading
+         self._client = None  # For OpenAI client
+
+         # Load model/client if not lazy
+         if not self.lazy:
+             if self.backend == "openai":
+                 self._load_openai_client()
+             elif self.backend in ["vllm", "transformers"]:
+                 self._load_model()
+
+     def _determine_backend(self, backend: Optional[Literal["vllm", "transformers", "openai"]]) -> str:
+         """Determine the appropriate backend based on url_or_model and user preference."""
+         if backend is not None:
+             valid_backends = ["vllm", "transformers", "openai"]
+             if backend not in valid_backends:
+                 raise ValueError(f"Invalid backend '{backend}'. Must be one of: {valid_backends}")
+             return backend
+
+         if self.url_or_model.startswith("http"):
+             return "openai"
+
+         # Default to vllm for local models
+         return "vllm"
+     def _try_infer_model_name(self, model_name: Optional[str]) -> Optional[str]:
+         """Infer model name for OpenAI backend if not explicitly provided."""
+         # if self.backend != "openai":
+         #     return model_name
+         if model_name:
+             return model_name
+         if 'https://' in self.url_or_model:
+             model_name = "text-embedding-3-small"
+         if 'http://localhost' in self.url_or_model:
+             from openai import OpenAI
+             client = OpenAI(base_url=self.url_or_model, api_key='abc')
+             model_name = client.models.list().data[0].id
+
+         # Default model name
+         print('Infer model name:', model_name)
+         return model_name
+     def _optimize_connection(self) -> None:
+         """Optimize SQLite connection for bulk operations."""
+         # Performance optimizations for bulk operations
+         self.conn.execute(
+             "PRAGMA journal_mode=WAL"
+         )  # Write-Ahead Logging for better concurrency
+         self.conn.execute("PRAGMA synchronous=NORMAL")  # Faster writes, still safe
+         self.conn.execute(f"PRAGMA cache_size={self.config['sqlite_cache_size']}")  # Configurable cache
+         self.conn.execute("PRAGMA temp_store=MEMORY")  # Use memory for temp storage
+         self.conn.execute(f"PRAGMA mmap_size={self.config['sqlite_mmap_size']}")  # Configurable memory mapping
+
+     def _ensure_schema(self) -> None:
+         self.conn.execute("""
+             CREATE TABLE IF NOT EXISTS cache (
+                 hash TEXT PRIMARY KEY,
+                 text TEXT,
+                 embedding BLOB
+             )
+         """)
+         # Add index for faster lookups if it doesn't exist
+         self.conn.execute("""
+             CREATE INDEX IF NOT EXISTS idx_cache_hash ON cache(hash)
+         """)
+         self.conn.commit()
+
+     def _load_openai_client(self) -> None:
+         """Load OpenAI client."""
+         import openai
+         self._client = openai.OpenAI(
+             base_url=self.url_or_model,
+             api_key=self.config["api_key"]
+         )
+
+     def _load_model(self) -> None:
+         """Load the model for vLLM or Transformers."""
+         if self.backend == "vllm":
+             from vllm import LLM
+
+             gpu_memory_utilization = cast(float, self.config["vllm_gpu_memory_utilization"])
+             tensor_parallel_size = cast(int, self.config["vllm_tensor_parallel_size"])
+             dtype = cast(str, self.config["vllm_dtype"])
+             trust_remote_code = cast(bool, self.config["vllm_trust_remote_code"])
+             max_model_len = cast(Optional[int], self.config["vllm_max_model_len"])
+
+             vllm_kwargs = {
+                 "model": self.url_or_model,
+                 "task": "embed",
+                 "gpu_memory_utilization": gpu_memory_utilization,
+                 "tensor_parallel_size": tensor_parallel_size,
+                 "dtype": dtype,
+                 "trust_remote_code": trust_remote_code,
+             }
+
+             if max_model_len is not None:
+                 vllm_kwargs["max_model_len"] = max_model_len
+
+             try:
+                 self._model = LLM(**vllm_kwargs)
+             except (ValueError, AssertionError, RuntimeError) as e:
+                 error_msg = str(e).lower()
+                 if ("kv cache" in error_msg and "gpu_memory_utilization" in error_msg) or \
+                     ("memory" in error_msg and ("gpu" in error_msg or "insufficient" in error_msg)) or \
+                     ("free memory" in error_msg and "initial" in error_msg) or \
+                     ("engine core initialization failed" in error_msg):
+                     raise ValueError(
+                         f"Insufficient GPU memory for vLLM model initialization. "
+                         f"Current vllm_gpu_memory_utilization ({gpu_memory_utilization}) may be too low. "
+                         f"Try one of the following:\n"
+                         f"1. Increase vllm_gpu_memory_utilization (e.g., 0.5, 0.8, or 0.9)\n"
+                         f"2. Decrease vllm_max_model_len (e.g., 4096, 8192)\n"
+                         f"3. Use a smaller model\n"
+                         f"4. Ensure no other processes are using GPU memory during initialization\n"
+                         f"Original error: {e}"
+                     ) from e
+                 else:
+                     raise
+         elif self.backend == "transformers":
+             from transformers import AutoTokenizer, AutoModel
+             import torch
+
+             device = self.config["transformers_device"]
+             # Handle "auto" device selection - default to CPU for transformers to avoid memory conflicts
+             if device == "auto":
+                 device = "cpu"  # Default to CPU to avoid GPU memory conflicts with vLLM
+
+             tokenizer = AutoTokenizer.from_pretrained(self.url_or_model, padding_side='left', trust_remote_code=self.config["transformers_trust_remote_code"])
+             model = AutoModel.from_pretrained(self.url_or_model, trust_remote_code=self.config["transformers_trust_remote_code"])
+
+             # Move model to device
+             model.to(device)
+             model.eval()
+
+             self._model = {"tokenizer": tokenizer, "model": model, "device": device}
+
+     def _get_embeddings(self, texts: list[str]) -> list[list[float]]:
+         """Get embeddings using the configured backend."""
+         if self.backend == "openai":
+             return self._get_openai_embeddings(texts)
+         elif self.backend == "vllm":
+             return self._get_vllm_embeddings(texts)
+         elif self.backend == "transformers":
+             return self._get_transformers_embeddings(texts)
+         else:
+             raise ValueError(f"Unsupported backend: {self.backend}")
+
+     def _get_openai_embeddings(self, texts: list[str]) -> list[list[float]]:
+         """Get embeddings using OpenAI API."""
+         # Assert valid model_name for OpenAI backend
+         model_name = self.config["model_name"]
+         assert model_name is not None and model_name.strip(), f"Invalid model_name for OpenAI backend: {model_name}. Model name must be provided and non-empty."
+
+         if self._client is None:
+             self._load_openai_client()
+
+         response = self._client.embeddings.create(  # type: ignore
+             model=model_name,
+             input=texts
+         )
+         embeddings = [item.embedding for item in response.data]
+         return embeddings
+
+     def _get_vllm_embeddings(self, texts: list[str]) -> list[list[float]]:
+         """Get embeddings using vLLM."""
+         if self._model is None:
+             self._load_model()
+
+         outputs = self._model.embed(texts)  # type: ignore
+         embeddings = [o.outputs.embedding for o in outputs]
+         return embeddings
+
+     def _get_transformers_embeddings(self, texts: list[str]) -> list[list[float]]:
+         """Get embeddings using transformers directly."""
+         if self._model is None:
+             self._load_model()
+
+         if not isinstance(self._model, dict):
+             raise ValueError("Model not loaded properly for transformers backend")
+
+         tokenizer = self._model["tokenizer"]
+         model = self._model["model"]
+         device = self._model["device"]
+
+         normalize_embeddings = cast(bool, self.config["transformers_normalize_embeddings"])
+
+         # For now, use a default max_length
+         max_length = 8192
+
+         # Tokenize
+         batch_dict = tokenizer(
+             texts,
+             padding=True,
+             truncation=True,
+             max_length=max_length,
+             return_tensors="pt",
+         )
+
+         # Move to device
+         batch_dict = {k: v.to(device) for k, v in batch_dict.items()}
+
+         # Run model
+         import torch
+         with torch.no_grad():
+             outputs = model(**batch_dict)
+
+         # Apply last token pooling
+         embeddings = self._last_token_pool(outputs.last_hidden_state, batch_dict['attention_mask'])
+
+         # Normalize if needed
+         if normalize_embeddings:
+             import torch.nn.functional as F
+             embeddings = F.normalize(embeddings, p=2, dim=1)
+
+         return embeddings.cpu().numpy().tolist()
+
+     def _last_token_pool(self, last_hidden_states, attention_mask):
+         """Apply last token pooling to get embeddings."""
+         import torch
+         left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
+         if left_padding:
+             return last_hidden_states[:, -1]
+         else:
+             sequence_lengths = attention_mask.sum(dim=1) - 1
+             batch_size = last_hidden_states.shape[0]
+             return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths]
+
+     def _hash_text(self, text: str) -> str:
+         return hashlib.sha1(text.encode("utf-8")).hexdigest()
+
+     def embeds(self, texts: list[str], cache: bool = True) -> np.ndarray:
+         """
+         Return embeddings for all texts.
+
+         If cache=True, compute and cache missing embeddings.
+         If cache=False, force recompute all embeddings and update cache.
+
+         This method processes lookups and embedding generation in chunks to
+         handle very large input lists. A tqdm progress bar is shown while
+         computing missing embeddings.
+         """
+         if not texts:
+             return np.empty((0, 0), dtype=np.float32)
+         t = time()
+         hashes = [self._hash_text(t) for t in texts]
+
+         # Helper to yield chunks
+         def _chunks(lst: list[str], n: int) -> list[list[str]]:
+             return [lst[i : i + n] for i in range(0, len(lst), n)]
+
+         # Fetch known embeddings in bulk with optimized chunk size
+         hit_map: dict[str, np.ndarray] = {}
+         chunk_size = self.config["sqlite_chunk_size"]
+
+         # Use bulk lookup with optimized query
+         hash_chunks = _chunks(hashes, chunk_size)
+         for chunk in hash_chunks:
+             placeholders = ",".join("?" * len(chunk))
+             rows = self.conn.execute(
+                 f"SELECT hash, embedding FROM cache WHERE hash IN ({placeholders})",
+                 chunk,
+             ).fetchall()
+             for h, e in rows:
+                 hit_map[h] = np.frombuffer(e, dtype=np.float32)
+
+         # Determine which texts are missing
+         if cache:
+             missing_items: list[tuple[str, str]] = [
+                 (t, h) for t, h in zip(texts, hashes) if h not in hit_map
+             ]
+         else:
+             missing_items: list[tuple[str, str]] = [
+                 (t, h) for t, h in zip(texts, hashes)
+             ]
+
+         if missing_items:
+             if self.verbose:
+                 print(f"Computing embeddings for {len(missing_items)} missing texts...")
+             missing_texts = [t for t, _ in missing_items]
+             embeds = self._get_embeddings(missing_texts)
+
+             # Prepare batch data for bulk insert
+             bulk_insert_data: list[tuple[str, str, bytes]] = []
+             for (text, h), vec in zip(missing_items, embeds):
+                 arr = np.asarray(vec, dtype=np.float32)
+                 bulk_insert_data.append((h, text, arr.tobytes()))
+                 hit_map[h] = arr
+
+             self._bulk_insert(bulk_insert_data)
+
+         # Return embeddings in the original order
+         elapsed = time() - t
+         if self.verbose:
+             print(f"Retrieved {len(texts)} embeddings in {elapsed:.2f} seconds")
+         return np.vstack([hit_map[h] for h in hashes])
+
+     def __call__(self, texts: list[str], cache: bool = True) -> np.ndarray:
+         return self.embeds(texts, cache)
+
+     def _bulk_insert(self, data: list[tuple[str, str, bytes]]) -> None:
+         """Perform bulk insert of embedding data."""
+         if not data:
+             return
+
+         self.conn.executemany(
+             "INSERT OR REPLACE INTO cache (hash, text, embedding) VALUES (?, ?, ?)",
+             data,
+         )
+         self.conn.commit()
+
+     def precompute_embeddings(self, texts: list[str]) -> None:
+         """
+         Precompute embeddings for a large list of texts efficiently.
+         This is optimized for bulk operations when you know all texts upfront.
+         """
+         if not texts:
+             return
+
+         # Remove duplicates while preserving order
+         unique_texts = list(dict.fromkeys(texts))
+         if self.verbose:
+             print(f"Precomputing embeddings for {len(unique_texts)} unique texts...")
+
+         # Check which ones are already cached
+         hashes = [self._hash_text(t) for t in unique_texts]
+         existing_hashes = set()
+
+         # Bulk check for existing embeddings
+         chunk_size = self.config["sqlite_chunk_size"]
+         for i in range(0, len(hashes), chunk_size):
+             chunk = hashes[i : i + chunk_size]
+             placeholders = ",".join("?" * len(chunk))
+             rows = self.conn.execute(
+                 f"SELECT hash FROM cache WHERE hash IN ({placeholders})",
+                 chunk,
+             ).fetchall()
+             existing_hashes.update(h[0] for h in rows)
+
+         # Find missing texts
+         missing_items = [
+             (t, h) for t, h in zip(unique_texts, hashes) if h not in existing_hashes
+         ]
+
+         if not missing_items:
+             if self.verbose:
+                 print("All texts already cached!")
+             return
+
+         if self.verbose:
+             print(f"Computing {len(missing_items)} missing embeddings...")
+         missing_texts = [t for t, _ in missing_items]
+         embeds = self._get_embeddings(missing_texts)
+
+         # Prepare batch data for bulk insert
+         bulk_insert_data: list[tuple[str, str, bytes]] = []
+         for (text, h), vec in zip(missing_items, embeds):
+             arr = np.asarray(vec, dtype=np.float32)
+             bulk_insert_data.append((h, text, arr.tobytes()))
+
+         self._bulk_insert(bulk_insert_data)
+         if self.verbose:
+             print(f"Successfully cached {len(missing_items)} new embeddings!")
+
+     def get_cache_stats(self) -> dict[str, int]:
+         """Get statistics about the cache."""
+         cursor = self.conn.execute("SELECT COUNT(*) FROM cache")
+         count = cursor.fetchone()[0]
+         return {"total_cached": count}
+
+     def clear_cache(self) -> None:
+         """Clear all cached embeddings."""
+         self.conn.execute("DELETE FROM cache")
+         self.conn.commit()
+
+     def get_config(self) -> Dict[str, Any]:
+         """Get current configuration."""
+         return {
+             "url_or_model": self.url_or_model,
+             "backend": self.backend,
+             "embed_size": self.embed_size,
+             "db_path": str(self.db_path),
+             "verbose": self.verbose,
+             "lazy": self.lazy,
+             **self.config
+         }
+
+     def update_config(self, **kwargs) -> None:
+         """Update configuration parameters."""
+         for key, value in kwargs.items():
+             if key in self.config:
+                 self.config[key] = value
+             elif key == "verbose":
+                 self.verbose = value
+             elif key == "lazy":
+                 self.lazy = value
+             else:
+                 raise ValueError(f"Unknown configuration parameter: {key}")
+
+         # Reset model if backend-specific parameters changed
+         backend_params = {
+             "vllm": ["vllm_gpu_memory_utilization", "vllm_tensor_parallel_size", "vllm_dtype",
+                      "vllm_trust_remote_code", "vllm_max_model_len"],
+             "transformers": ["transformers_device", "transformers_batch_size",
+                              "transformers_normalize_embeddings", "transformers_trust_remote_code"],
+             "openai": ["api_key", "model_name"]
+         }
+
+         if any(param in kwargs for param in backend_params.get(self.backend, [])):
+             self._model = None  # Force reload on next use
+             if self.backend == "openai":
+                 self._client = None
+
+     def __del__(self) -> None:
+         """Clean up database connection."""
+         if hasattr(self, "conn"):
+             self.conn.close()
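
Not part of the diff: the cache above keys each row by the SHA-1 of the text and stores the vector as a raw float32 BLOB. A self-contained sketch of that round trip, using the same table layout as _ensure_schema (the vector here is random stand-in data):

    import hashlib
    import sqlite3

    import numpy as np

    conn = sqlite3.connect(":memory:")
    conn.execute("CREATE TABLE cache (hash TEXT PRIMARY KEY, text TEXT, embedding BLOB)")

    text = "Hello world"
    key = hashlib.sha1(text.encode("utf-8")).hexdigest()  # same key as VectorCache._hash_text
    vec = np.random.rand(8).astype(np.float32)            # stand-in for a real embedding
    conn.execute("INSERT OR REPLACE INTO cache VALUES (?, ?, ?)", (key, text, vec.tobytes()))

    blob = conn.execute("SELECT embedding FROM cache WHERE hash = ?", (key,)).fetchone()[0]
    restored = np.frombuffer(blob, dtype=np.float32)      # same decoding as embeds()
    assert np.array_equal(vec, restored)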

speedy_utils-1.1.14/src/llm_utils/vector_cache/types.py
@@ -0,0 +1,15 @@
+ """Type definitions for the embed_cache package."""
+
+ from typing import List, Dict, Any, Union, Optional, Tuple
+ import numpy as np
+ from numpy.typing import NDArray
+
+ # Type aliases
+ TextList = List[str]
+ EmbeddingArray = NDArray[np.float32]
+ EmbeddingList = List[List[float]]
+ CacheStats = Dict[str, int]
+ ModelIdentifier = str  # Either URL or model name/path
+
+ # For backwards compatibility
+ Embeddings = Union[EmbeddingArray, EmbeddingList]

speedy_utils-1.1.14/src/llm_utils/vector_cache/utils.py
@@ -0,0 +1,42 @@
+ """Utility functions for the embed_cache package."""
+
+ import os
+ from typing import Optional
+
+ def get_default_cache_path() -> str:
+     """Get the default cache path based on environment."""
+     cache_dir = os.getenv("EMBED_CACHE_DIR", ".")
+     return os.path.join(cache_dir, "embed_cache.sqlite")
+
+ def validate_model_name(model_name: str) -> bool:
+     """Validate if a model name is supported."""
+     # Check if it's a URL
+     if model_name.startswith("http"):
+         return True
+
+     # Check if it's a valid model path/name
+     supported_prefixes = [
+         "Qwen/",
+         "sentence-transformers/",
+         "BAAI/",
+         "intfloat/",
+         "microsoft/",
+         "nvidia/",
+     ]
+
+     return any(model_name.startswith(prefix) for prefix in supported_prefixes) or os.path.exists(model_name)
+
+ def estimate_cache_size(num_texts: int, embedding_dim: int = 1024) -> str:
+     """Estimate cache size for given number of texts."""
+     # Rough estimate: hash (40 bytes) + text (avg 100 bytes) + embedding (embedding_dim * 4 bytes)
+     bytes_per_entry = 40 + 100 + (embedding_dim * 4)
+     total_bytes = num_texts * bytes_per_entry
+
+     if total_bytes < 1024:
+         return f"{total_bytes} bytes"
+     elif total_bytes < 1024 * 1024:
+         return f"{total_bytes / 1024:.1f} KB"
+     elif total_bytes < 1024 * 1024 * 1024:
+         return f"{total_bytes / (1024 * 1024):.1f} MB"
+     else:
+         return f"{total_bytes / (1024 * 1024 * 1024):.1f} GB"
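
Not part of the diff: a worked example of the estimate above. With the default 1024-dimensional embeddings each row is roughly 40 + 100 + 1024 * 4 = 4,236 bytes, so 100,000 texts come to about 423.6 million bytes, which the helper reports as 404.0 MB because it divides by 1024**2.

    from llm_utils.vector_cache import estimate_cache_size

    print(estimate_cache_size(100_000))                   # "404.0 MB"
    print(estimate_cache_size(1_000, embedding_dim=384))  # "1.6 MB" for a smaller model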

{speedy_utils-1.1.12 → speedy_utils-1.1.14}/src/speedy_utils/__init__.py
@@ -18,7 +18,7 @@
  # • memoize(func) -> Callable - Function result caching decorator
  # • identify(obj: Any) -> str - Generate unique object identifier
  # • identify_uuid(obj: Any) -> str - Generate UUID-based object identifier
- # • load_by_ext(fname: str | list[str]) -> Any - Auto-detect file format loader
+ # • load_by_ext(fname: Union[str, list[str]]) -> Any - Auto-detect file format loader
  # • dump_json_or_pickle(obj: Any, fname: str) -> None - Smart file serializer
  # • load_json_or_pickle(fname: str) -> Any - Smart file deserializer
  # • multi_thread(func, items, **kwargs) -> list - Parallel thread execution

{speedy_utils-1.1.12 → speedy_utils-1.1.14}/src/speedy_utils/common/logger.py
@@ -5,7 +5,7 @@ import re
  import sys
  import time
  from collections import OrderedDict
- from typing import Annotated, Literal
+ from typing import Annotated, Literal, Union

  from loguru import logger

@@ -166,7 +166,7 @@ def log(
      *,
      level: Literal["info", "warning", "error", "critical", "success"] = "info",
      once: bool = False,
-     interval: float | None = None,
+     interval: Union[float, None] = None,
  ) -> None:
      """
      Log a message using loguru with optional `once` and `interval` control.

{speedy_utils-1.1.12 → speedy_utils-1.1.14}/src/speedy_utils/common/utils_io.py
@@ -7,7 +7,7 @@ import pickle
  import time
  from glob import glob
  from pathlib import Path
- from typing import Any
+ from typing import Any, Union

  from json_repair import loads as jloads
  from pydantic import BaseModel
@@ -92,7 +92,7 @@ def load_jsonl(path):
      return [json.loads(line) for line in lines]


- def load_by_ext(fname: str | list[str], do_memoize: bool = False) -> Any:
+ def load_by_ext(fname: Union[str, list[str]], do_memoize: bool = False) -> Any:
      """
      Load data based on file extension.
      """

{speedy_utils-1.1.12 → speedy_utils-1.1.14}/src/speedy_utils/common/utils_print.py
@@ -3,7 +3,7 @@
  import copy
  import pprint
  import textwrap
- from typing import Any
+ from typing import Any, Union

  from tabulate import tabulate

@@ -24,17 +24,17 @@ def flatten_dict(d, parent_key="", sep="."):

  def fprint(
      input_data: Any,
-     key_ignore: list[str] | None = None,
-     key_keep: list[str] | None = None,
+     key_ignore: Union[list[str], None] = None,
+     key_keep: Union[list[str], None] = None,
      max_width: int = 100,
      indent: int = 2,
-     depth: int | None = None,
+     depth: Union[int, None] = None,
      table_format: str = "grid",
      str_wrap_width: int = 80,
      grep=None,
      is_notebook=None,
      f=print,
- ) -> None | str:
+ ) -> Union[None, str]:
      """
      Pretty print structured data.
      """

{speedy_utils-1.1.12 → speedy_utils-1.1.14}/src/speedy_utils/multi_worker/process.py
@@ -4,7 +4,7 @@ import traceback
  from collections.abc import Callable, Iterable, Iterator
  from concurrent.futures import ProcessPoolExecutor, as_completed
  from itertools import islice
- from typing import Any, TypeVar, cast
+ from typing import Any, TypeVar, Union, cast

  T = TypeVar("T")

@@ -65,12 +65,12 @@ def multi_process(
      func: Callable[[Any], Any],
      inputs: Iterable[Any],
      *,
-     workers: int | None = None,
+     workers: Union[int, None] = None,
      batch: int = 1,
      ordered: bool = True,
      progress: bool = False,
-     inflight: int | None = None,
-     timeout: float | None = None,
+     inflight: Union[int, None] = None,
+     timeout: Union[float, None] = None,
      stop_on_error: bool = True,
      process_update_interval=10,
      for_loop: bool = False,

{speedy_utils-1.1.12 → speedy_utils-1.1.14}/src/speedy_utils/multi_worker/thread.py
@@ -83,7 +83,7 @@ import traceback
  from collections.abc import Callable, Iterable
  from concurrent.futures import ThreadPoolExecutor, as_completed
  from itertools import islice
- from typing import Any, TypeVar
+ from typing import Any, TypeVar, Union

  from loguru import logger

@@ -125,16 +125,16 @@ def multi_thread(
      func: Callable,
      inputs: Iterable[Any],
      *,
-     workers: int | None = DEFAULT_WORKERS,
+     workers: Union[int, None] = DEFAULT_WORKERS,
      batch: int = 1,
      ordered: bool = True,
      progress: bool = True,
      progress_update: int = 10,
      prefetch_factor: int = 4,
-     timeout: float | None = None,
+     timeout: Union[float, None] = None,
      stop_on_error: bool = True,
      n_proc=0,
-     store_output_pkl_file: str | None = None,
+     store_output_pkl_file: Union[str, None] = None,
      **fixed_kwargs,
  ) -> list[Any]:
      """

The remaining files listed above (items 17-44) are unchanged between the two versions.