embedkit 0.1.4__py3-none-any.whl → 0.1.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- embedkit/__init__.py +10 -12
- embedkit/base.py +73 -4
- embedkit/classes.py +1 -6
- embedkit/models.py +7 -1
- embedkit/providers/cohere.py +31 -44
- embedkit/providers/colpali.py +69 -67
- embedkit/utils.py +113 -19
- {embedkit-0.1.4.dist-info → embedkit-0.1.6.dist-info}/METADATA +18 -12
- embedkit-0.1.6.dist-info/RECORD +13 -0
- embedkit-0.1.4.dist-info/RECORD +0 -13
- {embedkit-0.1.4.dist-info → embedkit-0.1.6.dist-info}/WHEEL +0 -0
- {embedkit-0.1.4.dist-info → embedkit-0.1.6.dist-info}/licenses/LICENSE +0 -0
embedkit/__init__.py
CHANGED
@@ -5,7 +5,6 @@ EmbedKit: A unified toolkit for generating vector embeddings.
|
|
5
5
|
|
6
6
|
from typing import Union, List, Optional
|
7
7
|
from pathlib import Path
|
8
|
-
import numpy as np
|
9
8
|
|
10
9
|
from .models import Model
|
11
10
|
from .base import EmbeddingError, EmbeddingResponse
|
@@ -28,7 +27,7 @@ class EmbedKit:
|
|
28
27
|
@classmethod
|
29
28
|
def colpali(
|
30
29
|
cls,
|
31
|
-
model: Model = Model.ColPali.
|
30
|
+
model: Model = Model.ColPali.COLPALI_V1_3,
|
32
31
|
device: Optional[str] = None,
|
33
32
|
text_batch_size: int = 32,
|
34
33
|
image_batch_size: int = 8,
|
@@ -42,13 +41,13 @@ class EmbedKit:
|
|
42
41
|
text_batch_size: Batch size for text embedding generation
|
43
42
|
image_batch_size: Batch size for image embedding generation
|
44
43
|
"""
|
45
|
-
if model
|
46
|
-
|
47
|
-
|
48
|
-
|
44
|
+
if not isinstance(model, Model.ColPali):
|
45
|
+
raise ValueError(
|
46
|
+
f"Unsupported model: {model}. Must be a Model.ColPali enum value."
|
47
|
+
)
|
49
48
|
|
50
49
|
provider = ColPaliProvider(
|
51
|
-
|
50
|
+
model=model,
|
52
51
|
device=device,
|
53
52
|
text_batch_size=text_batch_size,
|
54
53
|
image_batch_size=image_batch_size,
|
@@ -77,16 +76,15 @@ class EmbedKit:
|
|
77
76
|
if not api_key:
|
78
77
|
raise ValueError("API key is required")
|
79
78
|
|
80
|
-
if model
|
81
|
-
model_name = "embed-v4.0"
|
82
|
-
else:
|
79
|
+
if not isinstance(model, Model.Cohere):
|
83
80
|
raise ValueError(f"Unsupported model: {model}")
|
84
81
|
|
85
82
|
provider = CohereProvider(
|
86
|
-
api_key=api_key,
|
83
|
+
api_key=api_key,
|
84
|
+
model=model,
|
87
85
|
text_batch_size=text_batch_size,
|
88
86
|
image_batch_size=image_batch_size,
|
89
|
-
text_input_type=text_input_type
|
87
|
+
text_input_type=text_input_type,
|
90
88
|
)
|
91
89
|
return cls(provider)
|
92
90
|
|
embedkit/base.py
CHANGED
@@ -7,11 +7,15 @@ from pathlib import Path
|
|
7
7
|
import numpy as np
|
8
8
|
from dataclasses import dataclass
|
9
9
|
|
10
|
+
from .models import Model
|
11
|
+
from .utils import with_pdf_cleanup
|
12
|
+
|
10
13
|
|
11
14
|
@dataclass
|
12
15
|
class EmbeddingObject:
|
13
16
|
embedding: np.ndarray
|
14
17
|
source_b64: str = None
|
18
|
+
source_content_type: str = None # e.g., "image/png", "image/jpeg"
|
15
19
|
|
16
20
|
|
17
21
|
@dataclass
|
@@ -29,6 +33,67 @@ class EmbeddingResponse:
|
|
29
33
|
class EmbeddingProvider(ABC):
|
30
34
|
"""Abstract base class for embedding providers."""
|
31
35
|
|
36
|
+
def __init__(
|
37
|
+
self,
|
38
|
+
model_name: str,
|
39
|
+
text_batch_size: int,
|
40
|
+
image_batch_size: int,
|
41
|
+
provider_name: str,
|
42
|
+
):
|
43
|
+
self.model_name = model_name
|
44
|
+
self.provider_name = provider_name
|
45
|
+
self.text_batch_size = text_batch_size
|
46
|
+
self.image_batch_size = image_batch_size
|
47
|
+
|
48
|
+
def _normalize_text_input(self, texts: Union[str, List[str]]) -> List[str]:
|
49
|
+
"""Normalize text input to a list of strings."""
|
50
|
+
if isinstance(texts, str):
|
51
|
+
return [texts]
|
52
|
+
return texts
|
53
|
+
|
54
|
+
def _normalize_image_input(
|
55
|
+
self, images: Union[Path, str, List[Union[Path, str]]]
|
56
|
+
) -> List[Path]:
|
57
|
+
"""Normalize image input to a list of Path objects."""
|
58
|
+
if isinstance(images, (str, Path)):
|
59
|
+
return [Path(images)]
|
60
|
+
return [Path(img) for img in images]
|
61
|
+
|
62
|
+
def _create_text_response(
|
63
|
+
self, embeddings: List[np.ndarray], input_type: str = "text"
|
64
|
+
) -> EmbeddingResponse:
|
65
|
+
"""Create a standardized text embedding response."""
|
66
|
+
return EmbeddingResponse(
|
67
|
+
model_name=self.model_name,
|
68
|
+
model_provider=self.provider_name,
|
69
|
+
input_type=input_type,
|
70
|
+
objects=[EmbeddingObject(embedding=e) for e in embeddings],
|
71
|
+
)
|
72
|
+
|
73
|
+
def _create_image_response(
|
74
|
+
self,
|
75
|
+
embeddings: List[np.ndarray],
|
76
|
+
b64_data: List[str],
|
77
|
+
content_types: List[str],
|
78
|
+
input_type: str = "image",
|
79
|
+
) -> EmbeddingResponse:
|
80
|
+
"""Create a standardized image embedding response."""
|
81
|
+
return EmbeddingResponse(
|
82
|
+
model_name=self.model_name,
|
83
|
+
model_provider=self.provider_name,
|
84
|
+
input_type=input_type,
|
85
|
+
objects=[
|
86
|
+
EmbeddingObject(
|
87
|
+
embedding=embedding,
|
88
|
+
source_b64=b64_data,
|
89
|
+
source_content_type=content_type,
|
90
|
+
)
|
91
|
+
for embedding, b64_data, content_type in zip(
|
92
|
+
embeddings, b64_data, content_types
|
93
|
+
)
|
94
|
+
],
|
95
|
+
)
|
96
|
+
|
32
97
|
@abstractmethod
|
33
98
|
def embed_text(self, texts: Union[str, List[str]], **kwargs) -> EmbeddingResponse:
|
34
99
|
"""Generate document text embeddings using the configured provider."""
|
@@ -41,10 +106,14 @@ class EmbeddingProvider(ABC):
|
|
41
106
|
"""Generate image embeddings using the configured provider."""
|
42
107
|
pass
|
43
108
|
|
44
|
-
|
45
|
-
|
46
|
-
|
47
|
-
|
109
|
+
def embed_pdf(self, pdf_path: Path) -> EmbeddingResponse:
|
110
|
+
"""Generate embeddings for a PDF file."""
|
111
|
+
return self._embed_pdf_impl(pdf_path)
|
112
|
+
|
113
|
+
@with_pdf_cleanup
|
114
|
+
def _embed_pdf_impl(self, pdf_path: List[Path]) -> EmbeddingResponse:
|
115
|
+
"""Internal implementation of PDF embedding with cleanup handled by decorator."""
|
116
|
+
return self.embed_image(pdf_path)
|
48
117
|
|
49
118
|
|
50
119
|
class EmbeddingError(Exception):
|
embedkit/classes.py
CHANGED
@@ -13,9 +13,4 @@ from . import EmbeddingResponse, EmbeddingError
|
|
13
13
|
from .models import Model
|
14
14
|
from .providers.cohere import CohereInputType
|
15
15
|
|
16
|
-
__all__ = [
|
17
|
-
"EmbeddingResponse",
|
18
|
-
"EmbeddingError",
|
19
|
-
"Model",
|
20
|
-
"CohereInputType"
|
21
|
-
]
|
16
|
+
__all__ = ["EmbeddingResponse", "EmbeddingError", "Model", "CohereInputType"]
|
embedkit/models.py
CHANGED
@@ -6,7 +6,13 @@ from enum import Enum
|
|
6
6
|
|
7
7
|
class Model:
|
8
8
|
class ColPali(Enum):
|
9
|
-
|
9
|
+
COLPALI_V1_3 = "vidore/colpali-v1.3"
|
10
|
+
COLSMOL_500M = "vidore/colSmol-500M"
|
11
|
+
COLSMOL_256M = "vidore/colSmol-256M"
|
10
12
|
|
11
13
|
class Cohere(Enum):
|
12
14
|
EMBED_V4_0 = "embed-v4.0"
|
15
|
+
EMBED_ENGLISH_V3_0 = "embed-english-v3.0"
|
16
|
+
EMBED_ENGLISH_LIGHT_V3_0 = "embed-english-light-v3.0"
|
17
|
+
EMBED_MULTILINGUAL_V3_0 = "embed-multilingual-v3.0"
|
18
|
+
EMBED_MULTILINGUAL_LIGHT_V3_0 = "embed-multilingual-light-v3.0"
|
embedkit/providers/cohere.py
CHANGED
@@ -5,9 +5,13 @@ from typing import Union, List
|
|
5
5
|
from pathlib import Path
|
6
6
|
import numpy as np
|
7
7
|
from enum import Enum
|
8
|
+
import logging
|
8
9
|
|
9
|
-
from ..
|
10
|
-
from ..
|
10
|
+
from ..models import Model
|
11
|
+
from ..utils import image_to_base64
|
12
|
+
from ..base import EmbeddingProvider, EmbeddingError, EmbeddingResponse
|
13
|
+
|
14
|
+
logger = logging.getLogger(__name__)
|
11
15
|
|
12
16
|
|
13
17
|
class CohereInputType(Enum):
|
@@ -23,18 +27,20 @@ class CohereProvider(EmbeddingProvider):
|
|
23
27
|
def __init__(
|
24
28
|
self,
|
25
29
|
api_key: str,
|
26
|
-
|
30
|
+
model: Model.Cohere,
|
27
31
|
text_batch_size: int,
|
28
32
|
image_batch_size: int,
|
29
33
|
text_input_type: CohereInputType = CohereInputType.SEARCH_DOCUMENT,
|
30
34
|
):
|
35
|
+
super().__init__(
|
36
|
+
model_name=model.value,
|
37
|
+
text_batch_size=text_batch_size,
|
38
|
+
image_batch_size=image_batch_size,
|
39
|
+
provider_name="Cohere",
|
40
|
+
)
|
31
41
|
self.api_key = api_key
|
32
|
-
self.model_name = model_name
|
33
|
-
self.text_batch_size = text_batch_size
|
34
|
-
self.image_batch_size = image_batch_size
|
35
42
|
self.input_type = text_input_type
|
36
43
|
self._client = None
|
37
|
-
self.provider_name = "Cohere"
|
38
44
|
|
39
45
|
def _get_client(self):
|
40
46
|
"""Lazy load the Cohere client."""
|
@@ -54,9 +60,7 @@ class CohereProvider(EmbeddingProvider):
|
|
54
60
|
def embed_text(self, texts: Union[str, List[str]], **kwargs) -> EmbeddingResponse:
|
55
61
|
"""Generate text embeddings using the Cohere API."""
|
56
62
|
client = self._get_client()
|
57
|
-
|
58
|
-
if isinstance(texts, str):
|
59
|
-
texts = [texts]
|
63
|
+
texts = self._normalize_text_input(texts)
|
60
64
|
|
61
65
|
try:
|
62
66
|
all_embeddings = []
|
@@ -72,16 +76,7 @@ class CohereProvider(EmbeddingProvider):
|
|
72
76
|
)
|
73
77
|
all_embeddings.extend(np.array(response.embeddings.float_))
|
74
78
|
|
75
|
-
return
|
76
|
-
model_name=self.model_name,
|
77
|
-
model_provider=self.provider_name,
|
78
|
-
input_type=self.input_type.value,
|
79
|
-
objects=[
|
80
|
-
EmbeddingObject(
|
81
|
-
embedding=e,
|
82
|
-
) for e in all_embeddings
|
83
|
-
]
|
84
|
-
)
|
79
|
+
return self._create_text_response(all_embeddings, self.input_type.value)
|
85
80
|
|
86
81
|
except Exception as e:
|
87
82
|
raise EmbeddingError(f"Failed to embed text with Cohere: {e}") from e
|
@@ -92,12 +87,9 @@ class CohereProvider(EmbeddingProvider):
|
|
92
87
|
) -> EmbeddingResponse:
|
93
88
|
"""Generate embeddings for images using Cohere API."""
|
94
89
|
client = self._get_client()
|
95
|
-
|
96
|
-
|
97
|
-
|
98
|
-
images = [Path(images)]
|
99
|
-
else:
|
100
|
-
images = [Path(img) for img in images]
|
90
|
+
images = self._normalize_image_input(images)
|
91
|
+
total_images = len(images)
|
92
|
+
logger.info(f"Starting to process {total_images} images")
|
101
93
|
|
102
94
|
try:
|
103
95
|
all_embeddings = []
|
@@ -106,12 +98,17 @@ class CohereProvider(EmbeddingProvider):
|
|
106
98
|
# Process images in batches
|
107
99
|
for i in range(0, len(images), self.image_batch_size):
|
108
100
|
batch_images = images[i : i + self.image_batch_size]
|
101
|
+
logger.info(f"Processing batch {i//self.image_batch_size + 1} of {(total_images + self.image_batch_size - 1)//self.image_batch_size} ({len(batch_images)} images)")
|
109
102
|
b64_images = []
|
110
103
|
|
111
104
|
for image in batch_images:
|
112
105
|
if not image.exists():
|
113
106
|
raise EmbeddingError(f"Image not found: {image}")
|
114
|
-
|
107
|
+
b64_data, content_type = image_to_base64(image)
|
108
|
+
# Construct full data URI for API
|
109
|
+
data_uri = f"data:{content_type};base64,{b64_data}"
|
110
|
+
b64_images.append(data_uri)
|
111
|
+
all_b64_images.append((b64_data, content_type))
|
115
112
|
|
116
113
|
response = client.embed(
|
117
114
|
model=self.model_name,
|
@@ -121,24 +118,14 @@ class CohereProvider(EmbeddingProvider):
|
|
121
118
|
)
|
122
119
|
|
123
120
|
all_embeddings.extend(np.array(response.embeddings.float_))
|
124
|
-
|
125
|
-
|
126
|
-
return
|
127
|
-
|
128
|
-
|
129
|
-
|
130
|
-
objects=[
|
131
|
-
EmbeddingObject(
|
132
|
-
embedding=all_embeddings[i],
|
133
|
-
source_b64=all_b64_images[i]
|
134
|
-
) for i in range(len(all_embeddings))
|
135
|
-
]
|
121
|
+
|
122
|
+
logger.info(f"Successfully processed all {total_images} images")
|
123
|
+
return self._create_image_response(
|
124
|
+
all_embeddings,
|
125
|
+
[b64 for b64, _ in all_b64_images],
|
126
|
+
[content_type for _, content_type in all_b64_images],
|
136
127
|
)
|
137
128
|
|
138
129
|
except Exception as e:
|
130
|
+
logger.error(f"Failed to embed images: {e}")
|
139
131
|
raise EmbeddingError(f"Failed to embed image with Cohere: {e}") from e
|
140
|
-
|
141
|
-
def embed_pdf(self, pdf_path: Path) -> EmbeddingResponse:
|
142
|
-
"""Generate embeddings for a PDF file using Cohere API."""
|
143
|
-
image_paths = pdf_to_images(pdf_path)
|
144
|
-
return self.embed_image(image_paths)
|
embedkit/providers/colpali.py
CHANGED
@@ -8,8 +8,9 @@ import numpy as np
|
|
8
8
|
import torch
|
9
9
|
from PIL import Image
|
10
10
|
|
11
|
-
from ..
|
12
|
-
from ..
|
11
|
+
from ..models import Model
|
12
|
+
from ..utils import image_to_base64
|
13
|
+
from ..base import EmbeddingProvider, EmbeddingError, EmbeddingResponse
|
13
14
|
|
14
15
|
logger = logging.getLogger(__name__)
|
15
16
|
|
@@ -19,15 +20,17 @@ class ColPaliProvider(EmbeddingProvider):
|
|
19
20
|
|
20
21
|
def __init__(
|
21
22
|
self,
|
22
|
-
|
23
|
+
model: Model.ColPali,
|
23
24
|
text_batch_size: int,
|
24
25
|
image_batch_size: int,
|
25
26
|
device: Optional[str] = None,
|
26
27
|
):
|
27
|
-
|
28
|
-
|
29
|
-
|
30
|
-
|
28
|
+
super().__init__(
|
29
|
+
model_name=model.value,
|
30
|
+
text_batch_size=text_batch_size,
|
31
|
+
image_batch_size=image_batch_size,
|
32
|
+
provider_name="ColPali",
|
33
|
+
)
|
31
34
|
|
32
35
|
# Auto-detect device
|
33
36
|
if device is None:
|
@@ -38,24 +41,43 @@ class ColPaliProvider(EmbeddingProvider):
|
|
38
41
|
else:
|
39
42
|
device = "cpu"
|
40
43
|
|
41
|
-
self.
|
42
|
-
self.
|
43
|
-
self.
|
44
|
+
self._hf_device = device
|
45
|
+
self._hf_model = None
|
46
|
+
self._hf_processor = None
|
44
47
|
|
45
48
|
def _load_model(self):
|
46
49
|
"""Lazy load the model."""
|
47
|
-
if self.
|
50
|
+
if self._hf_model is None:
|
48
51
|
try:
|
49
|
-
|
50
|
-
|
51
|
-
|
52
|
-
self.
|
53
|
-
|
54
|
-
|
55
|
-
|
56
|
-
|
57
|
-
|
58
|
-
|
52
|
+
if self.model_name in [Model.ColPali.COLPALI_V1_3.value]:
|
53
|
+
from colpali_engine.models import ColPali, ColPaliProcessor
|
54
|
+
|
55
|
+
self._hf_model = ColPali.from_pretrained(
|
56
|
+
self.model_name,
|
57
|
+
torch_dtype=torch.bfloat16,
|
58
|
+
device_map=self._hf_device,
|
59
|
+
).eval()
|
60
|
+
|
61
|
+
self._hf_processor = ColPaliProcessor.from_pretrained(self.model_name)
|
62
|
+
|
63
|
+
elif self.model_name in [
|
64
|
+
Model.ColPali.COLSMOL_500M.value,
|
65
|
+
Model.ColPali.COLSMOL_256M.value,
|
66
|
+
]:
|
67
|
+
from colpali_engine.models import ColIdefics3, ColIdefics3Processor
|
68
|
+
|
69
|
+
self._hf_model = ColIdefics3.from_pretrained(
|
70
|
+
self.model_name,
|
71
|
+
torch_dtype=torch.bfloat16,
|
72
|
+
device_map=self._hf_device,
|
73
|
+
).eval()
|
74
|
+
self._hf_processor = ColIdefics3Processor.from_pretrained(
|
75
|
+
self.model_name
|
76
|
+
)
|
77
|
+
else:
|
78
|
+
raise ValueError(f"Unable to load model for: {self.model_name}.")
|
79
|
+
|
80
|
+
logger.info(f"Loaded {self.model_name} on {self._hf_device}")
|
59
81
|
|
60
82
|
except ImportError as e:
|
61
83
|
raise EmbeddingError(
|
@@ -64,38 +86,26 @@ class ColPaliProvider(EmbeddingProvider):
|
|
64
86
|
except Exception as e:
|
65
87
|
raise EmbeddingError(f"Failed to load model: {e}") from e
|
66
88
|
|
67
|
-
def embed_text(self, texts: Union[str, List[str]]) -> EmbeddingResponse:
|
89
|
+
def embed_text(self, texts: Union[str, List[str]], **kwargs) -> EmbeddingResponse:
|
68
90
|
"""Generate embeddings for text inputs."""
|
69
91
|
self._load_model()
|
70
|
-
|
71
|
-
if isinstance(texts, str):
|
72
|
-
texts = [texts]
|
92
|
+
texts = self._normalize_text_input(texts)
|
73
93
|
|
74
94
|
try:
|
75
95
|
# Process texts in batches
|
76
|
-
all_embeddings = []
|
96
|
+
all_embeddings: List[np.ndarray] = []
|
77
97
|
|
78
98
|
for i in range(0, len(texts), self.text_batch_size):
|
79
99
|
batch_texts = texts[i : i + self.text_batch_size]
|
80
|
-
processed = self.
|
100
|
+
processed = self._hf_processor.process_queries(batch_texts).to(self._hf_device)
|
81
101
|
|
82
102
|
with torch.no_grad():
|
83
|
-
batch_embeddings = self.
|
103
|
+
batch_embeddings = self._hf_model(**processed)
|
84
104
|
all_embeddings.append(batch_embeddings.cpu().float().numpy())
|
85
105
|
|
86
106
|
# Concatenate all batch embeddings
|
87
107
|
final_embeddings = np.concatenate(all_embeddings, axis=0)
|
88
|
-
|
89
|
-
return EmbeddingResponse(
|
90
|
-
model_name=self.model_name,
|
91
|
-
model_provider=self.provider_name,
|
92
|
-
input_type="text",
|
93
|
-
objects=[
|
94
|
-
EmbeddingObject(
|
95
|
-
embedding=e,
|
96
|
-
) for e in final_embeddings
|
97
|
-
]
|
98
|
-
)
|
108
|
+
return self._create_text_response(final_embeddings)
|
99
109
|
|
100
110
|
except Exception as e:
|
101
111
|
raise EmbeddingError(f"Failed to embed text: {e}") from e
|
@@ -105,21 +115,22 @@ class ColPaliProvider(EmbeddingProvider):
|
|
105
115
|
) -> EmbeddingResponse:
|
106
116
|
"""Generate embeddings for images."""
|
107
117
|
self._load_model()
|
108
|
-
|
109
|
-
|
110
|
-
|
111
|
-
else:
|
112
|
-
images = [Path(img) for img in images]
|
118
|
+
images = self._normalize_image_input(images)
|
119
|
+
total_images = len(images)
|
120
|
+
logger.info(f"Starting to process {total_images} images")
|
113
121
|
|
114
122
|
try:
|
115
123
|
# Process images in batches
|
116
|
-
all_embeddings = []
|
117
|
-
|
124
|
+
all_embeddings: List[np.ndarray] = []
|
125
|
+
all_b64_data: List[str] = []
|
126
|
+
all_content_types: List[str] = []
|
118
127
|
|
119
128
|
for i in range(0, len(images), self.image_batch_size):
|
120
129
|
batch_images = images[i : i + self.image_batch_size]
|
130
|
+
logger.info(f"Processing batch {i//self.image_batch_size + 1} of {(total_images + self.image_batch_size - 1)//self.image_batch_size} ({len(batch_images)} images)")
|
121
131
|
pil_images = []
|
122
|
-
|
132
|
+
batch_b64_data = []
|
133
|
+
batch_content_types = []
|
123
134
|
|
124
135
|
for img_path in batch_images:
|
125
136
|
if not img_path.exists():
|
@@ -127,34 +138,25 @@ class ColPaliProvider(EmbeddingProvider):
|
|
127
138
|
|
128
139
|
with Image.open(img_path) as img:
|
129
140
|
pil_images.append(img.convert("RGB"))
|
130
|
-
|
141
|
+
b64, content_type = image_to_base64(img_path)
|
142
|
+
batch_b64_data.append(b64)
|
143
|
+
batch_content_types.append(content_type)
|
131
144
|
|
132
|
-
processed = self.
|
145
|
+
processed = self._hf_processor.process_images(pil_images).to(self._hf_device)
|
133
146
|
|
134
147
|
with torch.no_grad():
|
135
|
-
batch_embeddings = self.
|
148
|
+
batch_embeddings = self._hf_model(**processed)
|
136
149
|
all_embeddings.append(batch_embeddings.cpu().float().numpy())
|
137
|
-
|
150
|
+
all_b64_data.extend(batch_b64_data)
|
151
|
+
all_content_types.extend(batch_content_types)
|
138
152
|
|
139
153
|
# Concatenate all batch embeddings
|
140
154
|
final_embeddings = np.concatenate(all_embeddings, axis=0)
|
141
|
-
|
142
|
-
return
|
143
|
-
|
144
|
-
model_provider=self.provider_name,
|
145
|
-
input_type="image",
|
146
|
-
objects=[
|
147
|
-
EmbeddingObject(
|
148
|
-
embedding=final_embeddings[i],
|
149
|
-
source_b64=all_b64_images[i]
|
150
|
-
) for i in range(len(final_embeddings))
|
151
|
-
]
|
155
|
+
logger.info(f"Successfully processed all {total_images} images")
|
156
|
+
return self._create_image_response(
|
157
|
+
final_embeddings, all_b64_data, all_content_types
|
152
158
|
)
|
153
159
|
|
154
160
|
except Exception as e:
|
161
|
+
logger.error(f"Failed to embed images: {e}")
|
155
162
|
raise EmbeddingError(f"Failed to embed images: {e}") from e
|
156
|
-
|
157
|
-
def embed_pdf(self, pdf_path: Path) -> EmbeddingResponse:
|
158
|
-
"""Generate embeddings for a PDF file using ColPali API."""
|
159
|
-
images = pdf_to_images(pdf_path)
|
160
|
-
return self.embed_image(images)
|
embedkit/utils.py
CHANGED
@@ -1,37 +1,84 @@
|
|
1
|
+
import tempfile
|
2
|
+
import shutil
|
3
|
+
import logging
|
4
|
+
from contextlib import contextmanager
|
1
5
|
from pdf2image import convert_from_path
|
2
6
|
from pathlib import Path
|
3
7
|
from .config import get_temp_dir
|
4
|
-
from typing import Union
|
8
|
+
from typing import Union, List, Iterator, Callable, TypeVar, Any
|
5
9
|
|
10
|
+
logger = logging.getLogger(__name__)
|
6
11
|
|
7
|
-
|
8
|
-
"""Convert a PDF file to a list of images."""
|
9
|
-
root_temp_dir = get_temp_dir()
|
10
|
-
img_temp_dir = root_temp_dir / "images"
|
11
|
-
img_temp_dir.mkdir(parents=True, exist_ok=True)
|
12
|
-
images = convert_from_path(pdf_path=str(pdf_path), output_folder=str(img_temp_dir))
|
13
|
-
image_paths = []
|
12
|
+
T = TypeVar("T")
|
14
13
|
|
15
|
-
for i, image in enumerate(images):
|
16
|
-
output_path = img_temp_dir / f"{pdf_path.stem}_{i}.png"
|
17
|
-
if output_path.exists():
|
18
|
-
output_path.unlink()
|
19
14
|
|
20
|
-
|
21
|
-
|
22
|
-
|
15
|
+
@contextmanager
|
16
|
+
def temporary_directory() -> Iterator[Path]:
|
17
|
+
"""Create a temporary directory that is automatically cleaned up when done.
|
23
18
|
|
19
|
+
Yields:
|
20
|
+
Path: Path to the temporary directory
|
21
|
+
"""
|
22
|
+
temp_dir = Path(tempfile.mkdtemp())
|
23
|
+
try:
|
24
|
+
yield temp_dir
|
25
|
+
finally:
|
26
|
+
shutil.rmtree(temp_dir)
|
27
|
+
|
28
|
+
|
29
|
+
def pdf_to_images(pdf_path: Path) -> List[Path]:
|
30
|
+
"""Convert a PDF file to a list of images.
|
31
|
+
|
32
|
+
The images are stored in a temporary directory that will be automatically
|
33
|
+
cleaned up when the process exits.
|
34
|
+
|
35
|
+
Args:
|
36
|
+
pdf_path: Path to the PDF file
|
37
|
+
|
38
|
+
Returns:
|
39
|
+
List[Path]: List of paths to the generated images
|
40
|
+
|
41
|
+
Note:
|
42
|
+
The temporary files will be automatically cleaned up when the process exits.
|
43
|
+
Do not rely on these files persisting after the function returns.
|
44
|
+
"""
|
45
|
+
with temporary_directory() as temp_dir:
|
46
|
+
images = convert_from_path(pdf_path=str(pdf_path), output_folder=str(temp_dir))
|
47
|
+
image_paths = []
|
48
|
+
|
49
|
+
for i, image in enumerate(images):
|
50
|
+
output_path = temp_dir / f"{pdf_path.stem}_{i}.png"
|
51
|
+
image.save(output_path)
|
52
|
+
final_path = Path(tempfile.mktemp(suffix=".png"))
|
53
|
+
shutil.move(output_path, final_path)
|
54
|
+
image_paths.append(final_path)
|
55
|
+
|
56
|
+
return image_paths
|
57
|
+
|
58
|
+
|
59
|
+
def image_to_base64(image_path: Union[str, Path]) -> tuple[str, str]:
|
60
|
+
"""Convert an image to base64 and return the base64 data and content type.
|
24
61
|
|
25
|
-
|
62
|
+
Args:
|
63
|
+
image_path: Path to the image file
|
64
|
+
|
65
|
+
Returns:
|
66
|
+
tuple[str, str]: (base64_data, content_type)
|
67
|
+
|
68
|
+
Raises:
|
69
|
+
ValueError: If the image cannot be read or has an unsupported format
|
70
|
+
"""
|
26
71
|
import base64
|
27
72
|
|
28
73
|
try:
|
29
|
-
|
74
|
+
base64_data = base64.b64encode(Path(image_path).read_bytes()).decode("utf-8")
|
30
75
|
except Exception as e:
|
31
76
|
raise ValueError(f"Failed to read image {image_path}: {e}") from e
|
32
77
|
|
33
78
|
if isinstance(image_path, Path):
|
34
79
|
image_path_str = str(image_path)
|
80
|
+
else:
|
81
|
+
image_path_str = image_path
|
35
82
|
|
36
83
|
if image_path_str.lower().endswith(".png"):
|
37
84
|
content_type = "image/png"
|
@@ -43,6 +90,53 @@ def image_to_base64(image_path: Union[str, Path]):
|
|
43
90
|
raise ValueError(
|
44
91
|
f"Unsupported image format for {image_path}; expected .png, .jpg, .jpeg, or .gif"
|
45
92
|
)
|
46
|
-
base64_image = f"data:{content_type};base64,{base64_only}"
|
47
93
|
|
48
|
-
return
|
94
|
+
return base64_data, content_type
|
95
|
+
|
96
|
+
|
97
|
+
def with_pdf_cleanup(embed_func: Callable[..., T]) -> Callable[..., T]:
|
98
|
+
"""Decorator to handle PDF to image conversion with automatic cleanup.
|
99
|
+
|
100
|
+
This decorator handles the common pattern of:
|
101
|
+
1. Converting PDF to images
|
102
|
+
2. Passing images to an embedding function
|
103
|
+
3. Cleaning up temporary files
|
104
|
+
|
105
|
+
Args:
|
106
|
+
embed_func: Function that takes a list of image paths and returns embeddings
|
107
|
+
|
108
|
+
Returns:
|
109
|
+
Callable that takes a PDF path and returns embeddings
|
110
|
+
"""
|
111
|
+
|
112
|
+
def wrapper(*args, **kwargs) -> T:
|
113
|
+
# First argument is self for instance methods
|
114
|
+
pdf_path = args[-1] if args else kwargs.get("pdf_path")
|
115
|
+
if not pdf_path:
|
116
|
+
raise ValueError(
|
117
|
+
"PDF path must be provided as the last positional argument or as 'pdf_path' keyword argument"
|
118
|
+
)
|
119
|
+
|
120
|
+
images = [] # Initialize images as empty list
|
121
|
+
try:
|
122
|
+
images = pdf_to_images(pdf_path)
|
123
|
+
# Call the original function with the images instead of pdf_path
|
124
|
+
if args:
|
125
|
+
# For instance methods, replace the last argument (pdf_path) with images
|
126
|
+
args = list(args)
|
127
|
+
args[-1] = images
|
128
|
+
else:
|
129
|
+
kwargs["pdf_path"] = images
|
130
|
+
return embed_func(*args, **kwargs)
|
131
|
+
finally:
|
132
|
+
# Clean up temporary files created by pdf_to_images
|
133
|
+
for img_path in images:
|
134
|
+
try:
|
135
|
+
if img_path.exists() and str(img_path).startswith(
|
136
|
+
tempfile.gettempdir()
|
137
|
+
):
|
138
|
+
img_path.unlink()
|
139
|
+
except Exception as e:
|
140
|
+
logger.warning(f"Failed to clean up temporary file {img_path}: {e}")
|
141
|
+
|
142
|
+
return wrapper
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: embedkit
|
3
|
-
Version: 0.1.
|
3
|
+
Version: 0.1.6
|
4
4
|
Summary: A simple toolkit for generating vector embeddings across multiple providers and models
|
5
5
|
Author-email: JP Hwang <me@jphwang.com>
|
6
6
|
License: MIT
|
@@ -22,7 +22,7 @@ Requires-Dist: colpali-engine<0.4.0,>=0.3.0
|
|
22
22
|
Requires-Dist: pdf2image>=1.17.0
|
23
23
|
Requires-Dist: pillow>=11.2.1
|
24
24
|
Requires-Dist: torch<=2.5
|
25
|
-
Requires-Dist: transformers
|
25
|
+
Requires-Dist: transformers>=4.46.2
|
26
26
|
Description-Content-Type: text/markdown
|
27
27
|
|
28
28
|
# EmbedKit
|
@@ -45,7 +45,7 @@ from embedkit.classes import Model, CohereInputType
|
|
45
45
|
|
46
46
|
# Initialize with ColPali
|
47
47
|
kit = EmbedKit.colpali(
|
48
|
-
model=Model.ColPali.
|
48
|
+
model=Model.ColPali.COLPALI_V1_3, # or COLSMOL_256M, COLSMOL_500M
|
49
49
|
text_batch_size=16, # Optional: process text in batches of 16
|
50
50
|
image_batch_size=8, # Optional: process images in batches of 8
|
51
51
|
)
|
@@ -54,7 +54,7 @@ kit = EmbedKit.colpali(
|
|
54
54
|
result = kit.embed_text("Hello world")
|
55
55
|
print(result.model_provider)
|
56
56
|
print(result.input_type)
|
57
|
-
print(result.objects[0].embedding.shape)
|
57
|
+
print(result.objects[0].embedding.shape) # Returns 2D array for ColPali
|
58
58
|
print(result.objects[0].source_b64)
|
59
59
|
|
60
60
|
# Initialize with Cohere
|
@@ -70,7 +70,7 @@ kit = EmbedKit.cohere(
|
|
70
70
|
result = kit.embed_text("Hello world")
|
71
71
|
print(result.model_provider)
|
72
72
|
print(result.input_type)
|
73
|
-
print(result.objects[0].embedding.shape)
|
73
|
+
print(result.objects[0].embedding.shape) # Returns 1D array for Cohere
|
74
74
|
print(result.objects[0].source_b64)
|
75
75
|
```
|
76
76
|
|
@@ -85,8 +85,8 @@ result = kit.embed_image(image_path)
|
|
85
85
|
|
86
86
|
print(result.model_provider)
|
87
87
|
print(result.input_type)
|
88
|
-
print(result.objects[0].embedding.shape)
|
89
|
-
print(result.objects[0].source_b64)
|
88
|
+
print(result.objects[0].embedding.shape) # 2D for ColPali, 1D for Cohere
|
89
|
+
print(result.objects[0].source_b64) # Base64 encoded image
|
90
90
|
```
|
91
91
|
|
92
92
|
### PDF Embeddings
|
@@ -100,8 +100,8 @@ result = kit.embed_pdf(pdf_path)
|
|
100
100
|
|
101
101
|
print(result.model_provider)
|
102
102
|
print(result.input_type)
|
103
|
-
print(result.objects[0].embedding.shape)
|
104
|
-
print(result.objects[0].source_b64)
|
103
|
+
print(result.objects[0].embedding.shape) # 2D for ColPali, 1D for Cohere
|
104
|
+
print(result.objects[0].source_b64) # Base64 encoded PDF page
|
105
105
|
```
|
106
106
|
|
107
107
|
## Response Format
|
@@ -116,17 +116,23 @@ class EmbeddingResponse:
|
|
116
116
|
objects: List[EmbeddingObject]
|
117
117
|
|
118
118
|
class EmbeddingObject:
|
119
|
-
embedding: np.ndarray
|
120
|
-
source_b64: Optional[str]
|
119
|
+
embedding: np.ndarray # 1D array for Cohere, 2D array for ColPali
|
120
|
+
source_b64: Optional[str] # Base64 encoded source for images and PDFs
|
121
121
|
```
|
122
122
|
|
123
123
|
## Supported Models
|
124
124
|
|
125
125
|
### ColPali
|
126
|
-
- `Model.ColPali.
|
126
|
+
- `Model.ColPali.COLPALI_V1_3`
|
127
|
+
- `Model.ColPali.COLSMOL_256M`
|
128
|
+
- `Model.ColPali.COLSMOL_500M`
|
127
129
|
|
128
130
|
### Cohere
|
129
131
|
- `Model.Cohere.EMBED_V4_0`
|
132
|
+
- `Model.Cohere.EMBED_ENGLISH_V3_0`
|
133
|
+
- `Model.Cohere.EMBED_ENGLISH_LIGHT_V3_0`
|
134
|
+
- `Model.Cohere.EMBED_MULTILINGUAL_V3_0`
|
135
|
+
- `Model.Cohere.EMBED_MULTILINGUAL_LIGHT_V3_0`
|
130
136
|
|
131
137
|
## Requirements
|
132
138
|
|
@@ -0,0 +1,13 @@
|
|
1
|
+
embedkit/__init__.py,sha256=ahMyC4SjYLr3QpM39obwDKpJO_HQ4_kqvG4d7jD5XtA,4503
|
2
|
+
embedkit/base.py,sha256=FfyzCqG7azs4yJj4RU6QKQRy9_wfd3KUFidThplmEo0,3739
|
3
|
+
embedkit/classes.py,sha256=LI5egTsfcBglUqIxBrDV_ymA1eKcxOUHJeMAhKVtU_Y,582
|
4
|
+
embedkit/config.py,sha256=EVGODSKxQAr46bU8dyORFunsfRuj6dnvtSqa4MxUZCo,138
|
5
|
+
embedkit/models.py,sha256=fndxt6QoV7nuyeXijs39K3QLs1l8ATw2PX2X-64oKt8,575
|
6
|
+
embedkit/utils.py,sha256=xqIjwkWDgDr1mCtbb5UbZo_Ci50griubxYJq3pSBv34,4560
|
7
|
+
embedkit/providers/__init__.py,sha256=HaS-HNQabvhn9xLNZCq3VUqPCb7rGG4pvgvpKP4AXcw,201
|
8
|
+
embedkit/providers/cohere.py,sha256=B5hgOIY_rexgc0s_DWjZ-yuuOXaYLQV-jmrY-gdBhz8,4726
|
9
|
+
embedkit/providers/colpali.py,sha256=8LsV-kkBmNYzCGXWeZy-KLEyMOxTvzOAS4gnzN5q_f4,6260
|
10
|
+
embedkit-0.1.6.dist-info/METADATA,sha256=tf0SLNvH5OOYFudiMmgEALYE6fnV716WJO0pPZIQJPw,3871
|
11
|
+
embedkit-0.1.6.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
12
|
+
embedkit-0.1.6.dist-info/licenses/LICENSE,sha256=-g2Rad7b3rb2oVwOTwfMOIpscHT1zuaJoguamLRCBJs,1072
|
13
|
+
embedkit-0.1.6.dist-info/RECORD,,
|
embedkit-0.1.4.dist-info/RECORD
DELETED
@@ -1,13 +0,0 @@
|
|
1
|
-
embedkit/__init__.py,sha256=O7C_e30VMrLliDamyNM_srdyuGz50PvGzBJo9NgePoU,4555
|
2
|
-
embedkit/base.py,sha256=QkP-Ih7aYjIaxP5Hgp7MVypQlI5mhyOqLWKe0lOpIrQ,1344
|
3
|
-
embedkit/classes.py,sha256=TlCl58Vo8p0q8SfJGINtBmCXY2yVFLtdlfQqBKrxivw,600
|
4
|
-
embedkit/config.py,sha256=EVGODSKxQAr46bU8dyORFunsfRuj6dnvtSqa4MxUZCo,138
|
5
|
-
embedkit/models.py,sha256=EBIYkyZeIhGaOPL-9bslHHdLaZ7qzOYLd0qxVZ7VX7w,226
|
6
|
-
embedkit/utils.py,sha256=91BzzvbYSrUsWeW3CTAw3yK-M3S5FgQXov16gxffkUo,1572
|
7
|
-
embedkit/providers/__init__.py,sha256=HaS-HNQabvhn9xLNZCq3VUqPCb7rGG4pvgvpKP4AXcw,201
|
8
|
-
embedkit/providers/cohere.py,sha256=5-ux5UzpRqlRPm2PjgmRfcxH4NhRRjXa3cOZGWp0jDc,4880
|
9
|
-
embedkit/providers/colpali.py,sha256=PUjTESyF0r__JdgfT22kK1DGrf7XwYC2EuyBelspTsc,5500
|
10
|
-
embedkit-0.1.4.dist-info/METADATA,sha256=KPCFTWyeh_8unP1zBmVCUmA0c-yfXwxeMdWf9zO4CjA,3316
|
11
|
-
embedkit-0.1.4.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
12
|
-
embedkit-0.1.4.dist-info/licenses/LICENSE,sha256=-g2Rad7b3rb2oVwOTwfMOIpscHT1zuaJoguamLRCBJs,1072
|
13
|
-
embedkit-0.1.4.dist-info/RECORD,,
|
File without changes
|
File without changes
|