embedkit 0.1.4__py3-none-any.whl → 0.1.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- embedkit/__init__.py +10 -12
- embedkit/base.py +73 -4
- embedkit/classes.py +1 -6
- embedkit/models.py +7 -1
- embedkit/providers/cohere.py +22 -43
- embedkit/providers/colpali.py +64 -67
- embedkit/utils.py +112 -19
- {embedkit-0.1.4.dist-info → embedkit-0.1.5.dist-info}/METADATA +18 -12
- embedkit-0.1.5.dist-info/RECORD +13 -0
- embedkit-0.1.4.dist-info/RECORD +0 -13
- {embedkit-0.1.4.dist-info → embedkit-0.1.5.dist-info}/WHEEL +0 -0
- {embedkit-0.1.4.dist-info → embedkit-0.1.5.dist-info}/licenses/LICENSE +0 -0
embedkit/__init__.py
CHANGED
@@ -5,7 +5,6 @@ EmbedKit: A unified toolkit for generating vector embeddings.
 
 from typing import Union, List, Optional
 from pathlib import Path
-import numpy as np
 
 from .models import Model
 from .base import EmbeddingError, EmbeddingResponse
@@ -28,7 +27,7 @@ class EmbedKit:
     @classmethod
     def colpali(
         cls,
-        model: Model = Model.ColPali.
+        model: Model = Model.ColPali.COLPALI_V1_3,
         device: Optional[str] = None,
         text_batch_size: int = 32,
         image_batch_size: int = 8,
@@ -42,13 +41,13 @@ class EmbedKit:
             text_batch_size: Batch size for text embedding generation
             image_batch_size: Batch size for image embedding generation
         """
-        if model
-
-
-
+        if not isinstance(model, Model.ColPali):
+            raise ValueError(
+                f"Unsupported model: {model}. Must be a Model.ColPali enum value."
+            )
 
         provider = ColPaliProvider(
-
+            model=model,
             device=device,
             text_batch_size=text_batch_size,
             image_batch_size=image_batch_size,
@@ -77,16 +76,15 @@ class EmbedKit:
         if not api_key:
             raise ValueError("API key is required")
 
-        if model
-            model_name = "embed-v4.0"
-        else:
+        if not isinstance(model, Model.Cohere):
             raise ValueError(f"Unsupported model: {model}")
 
         provider = CohereProvider(
-            api_key=api_key,
+            api_key=api_key,
+            model=model,
            text_batch_size=text_batch_size,
            image_batch_size=image_batch_size,
-            text_input_type=text_input_type
+            text_input_type=text_input_type,
        )
        return cls(provider)
 
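The net effect of these __init__.py changes is that both factory methods now validate an enum member instead of plumbing raw strings. A minimal usage sketch based on the new signatures (the API key value is a placeholder):

```python
from embedkit import EmbedKit
from embedkit.classes import Model

# ColPali: model now defaults to Model.ColPali.COLPALI_V1_3 and must be a
# Model.ColPali member, otherwise the factory raises ValueError.
kit = EmbedKit.colpali(model=Model.ColPali.COLSMOL_256M)

# Cohere: the enum (not a hardcoded "embed-v4.0" string) is passed through
# to CohereProvider, which derives model_name from model.value.
kit = EmbedKit.cohere(model=Model.Cohere.EMBED_V4_0, api_key="your-api-key")
```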
embedkit/base.py
CHANGED
@@ -7,11 +7,15 @@ from pathlib import Path
 import numpy as np
 from dataclasses import dataclass
 
+from .models import Model
+from .utils import with_pdf_cleanup
+
 
 @dataclass
 class EmbeddingObject:
     embedding: np.ndarray
     source_b64: str = None
+    source_content_type: str = None  # e.g., "image/png", "image/jpeg"
 
 
 @dataclass
@@ -29,6 +33,67 @@ class EmbeddingResponse:
 class EmbeddingProvider(ABC):
     """Abstract base class for embedding providers."""
 
+    def __init__(
+        self,
+        model_name: str,
+        text_batch_size: int,
+        image_batch_size: int,
+        provider_name: str,
+    ):
+        self.model_name = model_name
+        self.provider_name = provider_name
+        self.text_batch_size = text_batch_size
+        self.image_batch_size = image_batch_size
+
+    def _normalize_text_input(self, texts: Union[str, List[str]]) -> List[str]:
+        """Normalize text input to a list of strings."""
+        if isinstance(texts, str):
+            return [texts]
+        return texts
+
+    def _normalize_image_input(
+        self, images: Union[Path, str, List[Union[Path, str]]]
+    ) -> List[Path]:
+        """Normalize image input to a list of Path objects."""
+        if isinstance(images, (str, Path)):
+            return [Path(images)]
+        return [Path(img) for img in images]
+
+    def _create_text_response(
+        self, embeddings: List[np.ndarray], input_type: str = "text"
+    ) -> EmbeddingResponse:
+        """Create a standardized text embedding response."""
+        return EmbeddingResponse(
+            model_name=self.model_name,
+            model_provider=self.provider_name,
+            input_type=input_type,
+            objects=[EmbeddingObject(embedding=e) for e in embeddings],
+        )
+
+    def _create_image_response(
+        self,
+        embeddings: List[np.ndarray],
+        b64_data: List[str],
+        content_types: List[str],
+        input_type: str = "image",
+    ) -> EmbeddingResponse:
+        """Create a standardized image embedding response."""
+        return EmbeddingResponse(
+            model_name=self.model_name,
+            model_provider=self.provider_name,
+            input_type=input_type,
+            objects=[
+                EmbeddingObject(
+                    embedding=embedding,
+                    source_b64=b64_data,
+                    source_content_type=content_type,
+                )
+                for embedding, b64_data, content_type in zip(
+                    embeddings, b64_data, content_types
+                )
+            ],
+        )
+
     @abstractmethod
     def embed_text(self, texts: Union[str, List[str]], **kwargs) -> EmbeddingResponse:
         """Generate document text embeddings using the configured provider."""
@@ -41,10 +106,14 @@ class EmbeddingProvider(ABC):
         """Generate image embeddings using the configured provider."""
         pass
 
-
-
-
-
+    def embed_pdf(self, pdf_path: Path) -> EmbeddingResponse:
+        """Generate embeddings for a PDF file."""
+        return self._embed_pdf_impl(pdf_path)
+
+    @with_pdf_cleanup
+    def _embed_pdf_impl(self, pdf_path: List[Path]) -> EmbeddingResponse:
+        """Internal implementation of PDF embedding with cleanup handled by decorator."""
+        return self.embed_image(pdf_path)
 
 
 class EmbeddingError(Exception):
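base.py now carries the shared constructor and the normalization/response helpers that were previously duplicated in each provider. A hypothetical third-party provider could lean on them like this (DummyProvider and its zero vectors are illustrative only, not part of embedkit):

```python
import numpy as np

from embedkit.base import EmbeddingProvider, EmbeddingResponse
from embedkit.utils import image_to_base64

class DummyProvider(EmbeddingProvider):
    """Illustrative provider: returns zero vectors but reuses the new base helpers."""

    def __init__(self):
        super().__init__(
            model_name="dummy-model",
            text_batch_size=32,
            image_batch_size=8,
            provider_name="Dummy",
        )

    def embed_text(self, texts, **kwargs) -> EmbeddingResponse:
        texts = self._normalize_text_input(texts)  # str -> [str]
        return self._create_text_response([np.zeros(4) for _ in texts])

    def embed_image(self, images, **kwargs) -> EmbeddingResponse:
        paths = self._normalize_image_input(images)  # str/Path -> [Path]
        b64s, ctypes = zip(*(image_to_base64(p) for p in paths))
        return self._create_image_response(
            [np.zeros(4) for _ in paths], list(b64s), list(ctypes)
        )

# embed_pdf() comes for free: the base class converts the PDF to page images,
# forwards them to embed_image(), and cleans up via @with_pdf_cleanup.
```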
embedkit/classes.py
CHANGED
@@ -13,9 +13,4 @@ from . import EmbeddingResponse, EmbeddingError
 from .models import Model
 from .providers.cohere import CohereInputType
 
-__all__ = [
-    "EmbeddingResponse",
-    "EmbeddingError",
-    "Model",
-    "CohereInputType"
-]
+__all__ = ["EmbeddingResponse", "EmbeddingError", "Model", "CohereInputType"]
embedkit/models.py
CHANGED
@@ -6,7 +6,13 @@ from enum import Enum
 
 class Model:
     class ColPali(Enum):
-
+        COLPALI_V1_3 = "vidore/colpali-v1.3"
+        COLSMOL_500M = "vidore/colSmol-500M"
+        COLSMOL_256M = "vidore/colSmol-256M"
 
     class Cohere(Enum):
         EMBED_V4_0 = "embed-v4.0"
+        EMBED_ENGLISH_V3_0 = "embed-english-v3.0"
+        EMBED_ENGLISH_LIGHT_V3_0 = "embed-english-light-v3.0"
+        EMBED_MULTILINGUAL_V3_0 = "embed-multilingual-v3.0"
+        EMBED_MULTILINGUAL_LIGHT_V3_0 = "embed-multilingual-light-v3.0"
embedkit/providers/cohere.py
CHANGED
@@ -6,7 +6,8 @@ from pathlib import Path
 import numpy as np
 from enum import Enum
 
-from ..
+from ..models import Model
+from ..utils import image_to_base64
 from ..base import EmbeddingProvider, EmbeddingError, EmbeddingResponse, EmbeddingObject
 
 
@@ -23,18 +24,20 @@ class CohereProvider(EmbeddingProvider):
     def __init__(
         self,
         api_key: str,
-
+        model: Model.Cohere,
         text_batch_size: int,
         image_batch_size: int,
         text_input_type: CohereInputType = CohereInputType.SEARCH_DOCUMENT,
     ):
+        super().__init__(
+            model_name=model.value,
+            text_batch_size=text_batch_size,
+            image_batch_size=image_batch_size,
+            provider_name="Cohere",
+        )
         self.api_key = api_key
-        self.model_name = model_name
-        self.text_batch_size = text_batch_size
-        self.image_batch_size = image_batch_size
         self.input_type = text_input_type
         self._client = None
-        self.provider_name = "Cohere"
 
     def _get_client(self):
         """Lazy load the Cohere client."""
@@ -54,9 +57,7 @@ class CohereProvider(EmbeddingProvider):
     def embed_text(self, texts: Union[str, List[str]], **kwargs) -> EmbeddingResponse:
         """Generate text embeddings using the Cohere API."""
         client = self._get_client()
-
-        if isinstance(texts, str):
-            texts = [texts]
+        texts = self._normalize_text_input(texts)
 
         try:
             all_embeddings = []
@@ -72,16 +73,7 @@ class CohereProvider(EmbeddingProvider):
             )
             all_embeddings.extend(np.array(response.embeddings.float_))
 
-            return
-                model_name=self.model_name,
-                model_provider=self.provider_name,
-                input_type=self.input_type.value,
-                objects=[
-                    EmbeddingObject(
-                        embedding=e,
-                    ) for e in all_embeddings
-                ]
-            )
+            return self._create_text_response(all_embeddings, self.input_type.value)
 
         except Exception as e:
             raise EmbeddingError(f"Failed to embed text with Cohere: {e}") from e
@@ -92,12 +84,7 @@ class CohereProvider(EmbeddingProvider):
     ) -> EmbeddingResponse:
         """Generate embeddings for images using Cohere API."""
         client = self._get_client()
-
-
-        if isinstance(images, (str, Path)):
-            images = [Path(images)]
-        else:
-            images = [Path(img) for img in images]
+        images = self._normalize_image_input(images)
 
         try:
             all_embeddings = []
@@ -111,7 +98,11 @@ class CohereProvider(EmbeddingProvider):
                 for image in batch_images:
                     if not image.exists():
                         raise EmbeddingError(f"Image not found: {image}")
-
+                    b64_data, content_type = image_to_base64(image)
+                    # Construct full data URI for API
+                    data_uri = f"data:{content_type};base64,{b64_data}"
+                    b64_images.append(data_uri)
+                    all_b64_images.append((b64_data, content_type))
 
                 response = client.embed(
                     model=self.model_name,
@@ -121,24 +112,12 @@ class CohereProvider(EmbeddingProvider):
                 )
 
                 all_embeddings.extend(np.array(response.embeddings.float_))
-
-
-
-
-
-                input_type=input_type,
-                objects=[
-                    EmbeddingObject(
-                        embedding=all_embeddings[i],
-                        source_b64=all_b64_images[i]
-                    ) for i in range(len(all_embeddings))
-                ]
+
+            return self._create_image_response(
+                all_embeddings,
+                [b64 for b64, _ in all_b64_images],
+                [content_type for _, content_type in all_b64_images],
             )
 
         except Exception as e:
             raise EmbeddingError(f"Failed to embed image with Cohere: {e}") from e
-
-    def embed_pdf(self, pdf_path: Path) -> EmbeddingResponse:
-        """Generate embeddings for a PDF file using Cohere API."""
-        image_paths = pdf_to_images(pdf_path)
-        return self.embed_image(image_paths)
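The reworked image path keeps two representations: a full data URI for the Cohere API call, and the bare (base64, content type) pair for the response objects. A standalone sketch of that split (the file path is a placeholder):

```python
from embedkit.utils import image_to_base64

b64_data, content_type = image_to_base64("page_0.png")  # placeholder path
data_uri = f"data:{content_type};base64,{b64_data}"  # what the API call receives

# The response keeps the pieces separate: per _create_image_response in base.py,
# EmbeddingObject.source_b64 holds b64_data and source_content_type holds
# content_type, so downstream code never has to parse a "data:...;base64," prefix.
```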
embedkit/providers/colpali.py
CHANGED
@@ -8,8 +8,9 @@ import numpy as np
 import torch
 from PIL import Image
 
-from ..
-from ..
+from ..models import Model
+from ..utils import image_to_base64
+from ..base import EmbeddingProvider, EmbeddingError, EmbeddingResponse
 
 logger = logging.getLogger(__name__)
 
@@ -19,15 +20,17 @@ class ColPaliProvider(EmbeddingProvider):
 
     def __init__(
         self,
-
+        model: Model.ColPali,
         text_batch_size: int,
         image_batch_size: int,
         device: Optional[str] = None,
     ):
-
-
-
-
+        super().__init__(
+            model_name=model.value,
+            text_batch_size=text_batch_size,
+            image_batch_size=image_batch_size,
+            provider_name="ColPali",
+        )
 
         # Auto-detect device
         if device is None:
@@ -38,24 +41,43 @@ class ColPaliProvider(EmbeddingProvider):
         else:
             device = "cpu"
 
-        self.
-        self.
-        self.
+        self._hf_device = device
+        self._hf_model = None
+        self._hf_processor = None
 
     def _load_model(self):
         """Lazy load the model."""
-        if self.
+        if self._hf_model is None:
             try:
-
-
-
-                self.
-
-
-
-
-
+                if self.model_name in [Model.ColPali.COLPALI_V1_3.value]:
+                    from colpali_engine.models import ColPali, ColPaliProcessor
+
+                    self._hf_model = ColPali.from_pretrained(
+                        self.model_name,
+                        torch_dtype=torch.bfloat16,
+                        device_map=self._hf_device,
+                    ).eval()
+
+                    self._hf_processor = ColPaliProcessor.from_pretrained(self.model_name)
+
+                elif self.model_name in [
+                    Model.ColPali.COLSMOL_500M.value,
+                    Model.ColPali.COLSMOL_256M.value,
+                ]:
+                    from colpali_engine.models import ColIdefics3, ColIdefics3Processor
+
+                    self._hf_model = ColIdefics3.from_pretrained(
+                        self.model_name,
+                        torch_dtype=torch.bfloat16,
+                        device_map=self._hf_device,
+                    ).eval()
+                    self._hf_processor = ColIdefics3Processor.from_pretrained(
+                        self.model_name
+                    )
+                else:
+                    raise ValueError(f"Unable to load model for: {self.model_name}.")
+
+                logger.info(f"Loaded {self.model_name} on {self._hf_device}")
 
             except ImportError as e:
                 raise EmbeddingError(
@@ -64,38 +86,26 @@ class ColPaliProvider(EmbeddingProvider):
             except Exception as e:
                 raise EmbeddingError(f"Failed to load model: {e}") from e
 
-    def embed_text(self, texts: Union[str, List[str]]) -> EmbeddingResponse:
+    def embed_text(self, texts: Union[str, List[str]], **kwargs) -> EmbeddingResponse:
         """Generate embeddings for text inputs."""
         self._load_model()
-
-        if isinstance(texts, str):
-            texts = [texts]
+        texts = self._normalize_text_input(texts)
 
         try:
             # Process texts in batches
-            all_embeddings = []
+            all_embeddings: List[np.ndarray] = []
 
             for i in range(0, len(texts), self.text_batch_size):
                 batch_texts = texts[i : i + self.text_batch_size]
-                processed = self.
+                processed = self._hf_processor.process_queries(batch_texts).to(self._hf_device)
 
                 with torch.no_grad():
-                    batch_embeddings = self.
+                    batch_embeddings = self._hf_model(**processed)
                 all_embeddings.append(batch_embeddings.cpu().float().numpy())
 
             # Concatenate all batch embeddings
             final_embeddings = np.concatenate(all_embeddings, axis=0)
-
-            return EmbeddingResponse(
-                model_name=self.model_name,
-                model_provider=self.provider_name,
-                input_type="text",
-                objects=[
-                    EmbeddingObject(
-                        embedding=e,
-                    ) for e in final_embeddings
-                ]
-            )
+            return self._create_text_response(final_embeddings)
 
         except Exception as e:
             raise EmbeddingError(f"Failed to embed text: {e}") from e
@@ -105,21 +115,19 @@ class ColPaliProvider(EmbeddingProvider):
     ) -> EmbeddingResponse:
         """Generate embeddings for images."""
         self._load_model()
-
-        if isinstance(images, (str, Path)):
-            images = [Path(images)]
-        else:
-            images = [Path(img) for img in images]
+        images = self._normalize_image_input(images)
 
         try:
             # Process images in batches
-            all_embeddings = []
-
+            all_embeddings: List[np.ndarray] = []
+            all_b64_data: List[str] = []
+            all_content_types: List[str] = []
 
             for i in range(0, len(images), self.image_batch_size):
                 batch_images = images[i : i + self.image_batch_size]
                 pil_images = []
-
+                batch_b64_data = []
+                batch_content_types = []
 
                 for img_path in batch_images:
                     if not img_path.exists():
@@ -127,34 +135,23 @@ class ColPaliProvider(EmbeddingProvider):
 
                     with Image.open(img_path) as img:
                         pil_images.append(img.convert("RGB"))
-
+                    b64, content_type = image_to_base64(img_path)
+                    batch_b64_data.append(b64)
+                    batch_content_types.append(content_type)
 
-                processed = self.
+                processed = self._hf_processor.process_images(pil_images).to(self._hf_device)
 
                 with torch.no_grad():
-                    batch_embeddings = self.
+                    batch_embeddings = self._hf_model(**processed)
                 all_embeddings.append(batch_embeddings.cpu().float().numpy())
-
+                all_b64_data.extend(batch_b64_data)
+                all_content_types.extend(batch_content_types)
 
             # Concatenate all batch embeddings
             final_embeddings = np.concatenate(all_embeddings, axis=0)
-
-
-                model_name=self.model_name,
-                model_provider=self.provider_name,
-                input_type="image",
-                objects=[
-                    EmbeddingObject(
-                        embedding=final_embeddings[i],
-                        source_b64=all_b64_images[i]
-                    ) for i in range(len(final_embeddings))
-                ]
+            return self._create_image_response(
+                final_embeddings, all_b64_data, all_content_types
             )
 
         except Exception as e:
             raise EmbeddingError(f"Failed to embed images: {e}") from e
-
-    def embed_pdf(self, pdf_path: Path) -> EmbeddingResponse:
-        """Generate embeddings for a PDF file using ColPali API."""
-        images = pdf_to_images(pdf_path)
-        return self.embed_image(images)
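The lazy loader now dispatches on the model name: vidore/colpali-v1.3 loads through ColPali/ColPaliProcessor, while the two colSmol checkpoints go through ColIdefics3/ColIdefics3Processor. A condensed sketch of that branching, written as a standalone function rather than the provider's private method:

```python
from embedkit.models import Model

def select_engine_classes(model_name: str):
    """Mirror of _load_model()'s branching, extracted for illustration."""
    if model_name == Model.ColPali.COLPALI_V1_3.value:
        from colpali_engine.models import ColPali, ColPaliProcessor
        return ColPali, ColPaliProcessor
    if model_name in (
        Model.ColPali.COLSMOL_500M.value,
        Model.ColPali.COLSMOL_256M.value,
    ):
        from colpali_engine.models import ColIdefics3, ColIdefics3Processor
        return ColIdefics3, ColIdefics3Processor
    raise ValueError(f"Unable to load model for: {model_name}.")
```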
embedkit/utils.py
CHANGED
@@ -1,37 +1,84 @@
+import tempfile
+import shutil
+import logging
+from contextlib import contextmanager
 from pdf2image import convert_from_path
 from pathlib import Path
 from .config import get_temp_dir
-from typing import Union
+from typing import Union, List, Iterator, Callable, TypeVar, Any
 
+logger = logging.getLogger(__name__)
 
-
-    """Convert a PDF file to a list of images."""
-    root_temp_dir = get_temp_dir()
-    img_temp_dir = root_temp_dir / "images"
-    img_temp_dir.mkdir(parents=True, exist_ok=True)
-    images = convert_from_path(pdf_path=str(pdf_path), output_folder=str(img_temp_dir))
-    image_paths = []
+T = TypeVar("T")
 
-    for i, image in enumerate(images):
-        output_path = img_temp_dir / f"{pdf_path.stem}_{i}.png"
-        if output_path.exists():
-            output_path.unlink()
 
-
-
-
+@contextmanager
+def temporary_directory() -> Iterator[Path]:
+    """Create a temporary directory that is automatically cleaned up when done.
 
+    Yields:
+        Path: Path to the temporary directory
+    """
+    temp_dir = Path(tempfile.mkdtemp())
+    try:
+        yield temp_dir
+    finally:
+        shutil.rmtree(temp_dir)
+
+
+def pdf_to_images(pdf_path: Path) -> List[Path]:
+    """Convert a PDF file to a list of images.
+
+    The images are stored in a temporary directory that will be automatically
+    cleaned up when the process exits.
+
+    Args:
+        pdf_path: Path to the PDF file
+
+    Returns:
+        List[Path]: List of paths to the generated images
+
+    Note:
+        The temporary files will be automatically cleaned up when the process exits.
+        Do not rely on these files persisting after the function returns.
+    """
+    with temporary_directory() as temp_dir:
+        images = convert_from_path(pdf_path=str(pdf_path), output_folder=str(temp_dir))
+        image_paths = []
+
+        for i, image in enumerate(images):
+            output_path = temp_dir / f"{pdf_path.stem}_{i}.png"
+            image.save(output_path)
+            final_path = Path(tempfile.mktemp(suffix=".png"))
+            shutil.move(output_path, final_path)
+            image_paths.append(final_path)
+
+        return image_paths
+
+
+def image_to_base64(image_path: Union[str, Path]) -> tuple[str, str]:
+    """Convert an image to base64 and return the base64 data and content type.
 
-
+    Args:
+        image_path: Path to the image file
+
+    Returns:
+        tuple[str, str]: (base64_data, content_type)
+
+    Raises:
+        ValueError: If the image cannot be read or has an unsupported format
+    """
     import base64
 
     try:
-
+        base64_data = base64.b64encode(Path(image_path).read_bytes()).decode("utf-8")
     except Exception as e:
         raise ValueError(f"Failed to read image {image_path}: {e}") from e
 
     if isinstance(image_path, Path):
         image_path_str = str(image_path)
+    else:
+        image_path_str = image_path
 
     if image_path_str.lower().endswith(".png"):
         content_type = "image/png"
@@ -43,6 +90,52 @@ def image_to_base64(image_path: Union[str, Path]):
         raise ValueError(
             f"Unsupported image format for {image_path}; expected .png, .jpg, .jpeg, or .gif"
         )
-    base64_image = f"data:{content_type};base64,{base64_only}"
 
-    return
+    return base64_data, content_type
+
+
+def with_pdf_cleanup(embed_func: Callable[..., T]) -> Callable[..., T]:
+    """Decorator to handle PDF to image conversion with automatic cleanup.
+
+    This decorator handles the common pattern of:
+    1. Converting PDF to images
+    2. Passing images to an embedding function
+    3. Cleaning up temporary files
+
+    Args:
+        embed_func: Function that takes a list of image paths and returns embeddings
+
+    Returns:
+        Callable that takes a PDF path and returns embeddings
+    """
+
+    def wrapper(*args, **kwargs) -> T:
+        # First argument is self for instance methods
+        pdf_path = args[-1] if args else kwargs.get("pdf_path")
+        if not pdf_path:
+            raise ValueError(
+                "PDF path must be provided as the last positional argument or as 'pdf_path' keyword argument"
+            )
+
+        try:
+            images = pdf_to_images(pdf_path)
+            # Call the original function with the images instead of pdf_path
+            if args:
+                # For instance methods, replace the last argument (pdf_path) with images
+                args = list(args)
+                args[-1] = images
+            else:
+                kwargs["pdf_path"] = images
+            return embed_func(*args, **kwargs)
+        finally:
+            # Clean up temporary files created by pdf_to_images
+            for img_path in images:
+                try:
+                    if img_path.exists() and str(img_path).startswith(
+                        tempfile.gettempdir()
+                    ):
+                        img_path.unlink()
+                except Exception as e:
+                    logger.warning(f"Failed to clean up temporary file {img_path}: {e}")
+
+    return wrapper
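with_pdf_cleanup rewrites the trailing pdf_path argument into the list of rendered page images before calling the wrapped function, then unlinks any of those files still sitting under the system temp directory. A minimal standalone sketch, assuming the decorator is applied to a plain function rather than a provider method (count_pages and report.pdf are placeholders):

```python
from pathlib import Path
from embedkit.utils import with_pdf_cleanup

@with_pdf_cleanup
def count_pages(image_paths) -> int:
    # Receives List[Path] of page images, not the original PDF path.
    return len(image_paths)

n = count_pages(Path("report.pdf"))  # page images are deleted after the call
```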
{embedkit-0.1.4.dist-info → embedkit-0.1.5.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: embedkit
-Version: 0.1.4
+Version: 0.1.5
 Summary: A simple toolkit for generating vector embeddings across multiple providers and models
 Author-email: JP Hwang <me@jphwang.com>
 License: MIT
@@ -22,7 +22,7 @@ Requires-Dist: colpali-engine<0.4.0,>=0.3.0
 Requires-Dist: pdf2image>=1.17.0
 Requires-Dist: pillow>=11.2.1
 Requires-Dist: torch<=2.5
-Requires-Dist: transformers
+Requires-Dist: transformers>=4.46.2
 Description-Content-Type: text/markdown
 
 # EmbedKit
@@ -45,7 +45,7 @@ from embedkit.classes import Model, CohereInputType
 
 # Initialize with ColPali
 kit = EmbedKit.colpali(
-    model=Model.ColPali.
+    model=Model.ColPali.COLPALI_V1_3,  # or COLSMOL_256M, COLSMOL_500M
     text_batch_size=16,  # Optional: process text in batches of 16
     image_batch_size=8,  # Optional: process images in batches of 8
 )
@@ -54,7 +54,7 @@ kit = EmbedKit.colpali(
 result = kit.embed_text("Hello world")
 print(result.model_provider)
 print(result.input_type)
-print(result.objects[0].embedding.shape)
+print(result.objects[0].embedding.shape)  # Returns 2D array for ColPali
 print(result.objects[0].source_b64)
 
 # Initialize with Cohere
@@ -70,7 +70,7 @@ kit = EmbedKit.cohere(
 result = kit.embed_text("Hello world")
 print(result.model_provider)
 print(result.input_type)
-print(result.objects[0].embedding.shape)
+print(result.objects[0].embedding.shape)  # Returns 1D array for Cohere
 print(result.objects[0].source_b64)
 ```
 
@@ -85,8 +85,8 @@ result = kit.embed_image(image_path)
 
 print(result.model_provider)
 print(result.input_type)
-print(result.objects[0].embedding.shape)
-print(result.objects[0].source_b64)
+print(result.objects[0].embedding.shape)  # 2D for ColPali, 1D for Cohere
+print(result.objects[0].source_b64)  # Base64 encoded image
 ```
 
 ### PDF Embeddings
@@ -100,8 +100,8 @@ result = kit.embed_pdf(pdf_path)
 
 print(result.model_provider)
 print(result.input_type)
-print(result.objects[0].embedding.shape)
-print(result.objects[0].source_b64)
+print(result.objects[0].embedding.shape)  # 2D for ColPali, 1D for Cohere
+print(result.objects[0].source_b64)  # Base64 encoded PDF page
 ```
 
 ## Response Format
@@ -116,17 +116,23 @@ class EmbeddingResponse:
     objects: List[EmbeddingObject]
 
 class EmbeddingObject:
-    embedding: np.ndarray
-    source_b64: Optional[str]
+    embedding: np.ndarray  # 1D array for Cohere, 2D array for ColPali
+    source_b64: Optional[str]  # Base64 encoded source for images and PDFs
 ```
 
 ## Supported Models
 
 ### ColPali
-- `Model.ColPali.
+- `Model.ColPali.COLPALI_V1_3`
+- `Model.ColPali.COLSMOL_256M`
+- `Model.ColPali.COLSMOL_500M`
 
 ### Cohere
 - `Model.Cohere.EMBED_V4_0`
+- `Model.Cohere.EMBED_ENGLISH_V3_0`
+- `Model.Cohere.EMBED_ENGLISH_LIGHT_V3_0`
+- `Model.Cohere.EMBED_MULTILINGUAL_V3_0`
+- `Model.Cohere.EMBED_MULTILINGUAL_LIGHT_V3_0`
 
 ## Requirements
 
embedkit-0.1.5.dist-info/RECORD
ADDED
@@ -0,0 +1,13 @@
+embedkit/__init__.py,sha256=ahMyC4SjYLr3QpM39obwDKpJO_HQ4_kqvG4d7jD5XtA,4503
+embedkit/base.py,sha256=FfyzCqG7azs4yJj4RU6QKQRy9_wfd3KUFidThplmEo0,3739
+embedkit/classes.py,sha256=LI5egTsfcBglUqIxBrDV_ymA1eKcxOUHJeMAhKVtU_Y,582
+embedkit/config.py,sha256=EVGODSKxQAr46bU8dyORFunsfRuj6dnvtSqa4MxUZCo,138
+embedkit/models.py,sha256=fndxt6QoV7nuyeXijs39K3QLs1l8ATw2PX2X-64oKt8,575
+embedkit/utils.py,sha256=GV2dqNdI8XKP3iRD6zyoY5jlAcqVt1W5pxh3Rmn81Cc,4505
+embedkit/providers/__init__.py,sha256=HaS-HNQabvhn9xLNZCq3VUqPCb7rGG4pvgvpKP4AXcw,201
+embedkit/providers/cohere.py,sha256=sD9wsh8ifUYnuBbOeMklvBo49rlnsf-WU568W_fu8dM,4275
+embedkit/providers/colpali.py,sha256=swdn8xHyn-Xob3POQFEbhwcjR_oLp7uVUMXc2gOggK8,5845
+embedkit-0.1.5.dist-info/METADATA,sha256=hlUZDvuL7FXa1XjCZdZSftiT9f4rGc7DUoYkodAG94g,3871
+embedkit-0.1.5.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+embedkit-0.1.5.dist-info/licenses/LICENSE,sha256=-g2Rad7b3rb2oVwOTwfMOIpscHT1zuaJoguamLRCBJs,1072
+embedkit-0.1.5.dist-info/RECORD,,
embedkit-0.1.4.dist-info/RECORD
DELETED
@@ -1,13 +0,0 @@
-embedkit/__init__.py,sha256=O7C_e30VMrLliDamyNM_srdyuGz50PvGzBJo9NgePoU,4555
-embedkit/base.py,sha256=QkP-Ih7aYjIaxP5Hgp7MVypQlI5mhyOqLWKe0lOpIrQ,1344
-embedkit/classes.py,sha256=TlCl58Vo8p0q8SfJGINtBmCXY2yVFLtdlfQqBKrxivw,600
-embedkit/config.py,sha256=EVGODSKxQAr46bU8dyORFunsfRuj6dnvtSqa4MxUZCo,138
-embedkit/models.py,sha256=EBIYkyZeIhGaOPL-9bslHHdLaZ7qzOYLd0qxVZ7VX7w,226
-embedkit/utils.py,sha256=91BzzvbYSrUsWeW3CTAw3yK-M3S5FgQXov16gxffkUo,1572
-embedkit/providers/__init__.py,sha256=HaS-HNQabvhn9xLNZCq3VUqPCb7rGG4pvgvpKP4AXcw,201
-embedkit/providers/cohere.py,sha256=5-ux5UzpRqlRPm2PjgmRfcxH4NhRRjXa3cOZGWp0jDc,4880
-embedkit/providers/colpali.py,sha256=PUjTESyF0r__JdgfT22kK1DGrf7XwYC2EuyBelspTsc,5500
-embedkit-0.1.4.dist-info/METADATA,sha256=KPCFTWyeh_8unP1zBmVCUmA0c-yfXwxeMdWf9zO4CjA,3316
-embedkit-0.1.4.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-embedkit-0.1.4.dist-info/licenses/LICENSE,sha256=-g2Rad7b3rb2oVwOTwfMOIpscHT1zuaJoguamLRCBJs,1072
-embedkit-0.1.4.dist-info/RECORD,,
{embedkit-0.1.4.dist-info → embedkit-0.1.5.dist-info}/WHEEL
File without changes
{embedkit-0.1.4.dist-info → embedkit-0.1.5.dist-info}/licenses/LICENSE
File without changes