embedkit 0.1.1__tar.gz → 0.1.3__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {embedkit-0.1.1 → embedkit-0.1.3}/PKG-INFO +1 -1
- {embedkit-0.1.1 → embedkit-0.1.3}/main.py +17 -3
- {embedkit-0.1.1 → embedkit-0.1.3}/pyproject.toml +1 -1
- {embedkit-0.1.1 → embedkit-0.1.3}/src/embedkit/__init__.py +23 -4
- {embedkit-0.1.1 → embedkit-0.1.3}/src/embedkit/base.py +1 -3
- embedkit-0.1.3/src/embedkit/classes.py +21 -0
- {embedkit-0.1.1 → embedkit-0.1.3}/src/embedkit/providers/cohere.py +42 -21
- {embedkit-0.1.1 → embedkit-0.1.3}/src/embedkit/providers/colpali.py +44 -21
- {embedkit-0.1.1 → embedkit-0.1.3}/src/embedkit/utils.py +2 -6
- {embedkit-0.1.1 → embedkit-0.1.3}/.gitignore +0 -0
- {embedkit-0.1.1 → embedkit-0.1.3}/.python-version +0 -0
- {embedkit-0.1.1 → embedkit-0.1.3}/LICENSE +0 -0
- {embedkit-0.1.1 → embedkit-0.1.3}/README.md +0 -0
- {embedkit-0.1.1 → embedkit-0.1.3}/src/embedkit/config.py +0 -0
- {embedkit-0.1.1 → embedkit-0.1.3}/src/embedkit/models.py +0 -0
- {embedkit-0.1.1 → embedkit-0.1.3}/src/embedkit/providers/__init__.py +0 -0
- {embedkit-0.1.1 → embedkit-0.1.3}/tests/conftest.py +0 -0
- {embedkit-0.1.1 → embedkit-0.1.3}/tests/fixtures/2407.01449v6_p1.pdf +0 -0
- {embedkit-0.1.1 → embedkit-0.1.3}/tests/fixtures/2407.01449v6_p1.png +0 -0
- {embedkit-0.1.1 → embedkit-0.1.3}/tests/test_embedkit.py +0 -0
- {embedkit-0.1.1 → embedkit-0.1.3}/uv.lock +0 -0
@@ -1,7 +1,6 @@
|
|
1
1
|
# ./main.py
|
2
2
|
from embedkit import EmbedKit
|
3
|
-
from embedkit.
|
4
|
-
from embedkit.providers.cohere import CohereInputType
|
3
|
+
from embedkit.classes import Model, CohereInputType
|
5
4
|
from pathlib import Path
|
6
5
|
import os
|
7
6
|
|
@@ -33,8 +32,9 @@ def get_sample_image() -> Path:
|
|
33
32
|
sample_image = get_sample_image()
|
34
33
|
|
35
34
|
sample_pdf = Path("tests/fixtures/2407.01449v6_p1.pdf")
|
35
|
+
long_pdf = Path("tmp/2407.01449v6.pdf")
|
36
36
|
|
37
|
-
kit = EmbedKit.colpali(model=Model.ColPali.V1_3)
|
37
|
+
kit = EmbedKit.colpali(model=Model.ColPali.V1_3, text_batch_size=16, image_batch_size=8)
|
38
38
|
|
39
39
|
results = kit.embed_text("Hello world")
|
40
40
|
assert results.shape[0] == 1
|
@@ -50,10 +50,17 @@ assert results.shape[0] == 1
|
|
50
50
|
assert len(results.shape) == 3
|
51
51
|
assert len(results.source_images_b64) > 0
|
52
52
|
|
53
|
+
results = kit.embed_pdf(long_pdf)
|
54
|
+
assert results.shape[0] == 26
|
55
|
+
assert len(results.shape) == 3
|
56
|
+
assert len(results.source_images_b64) > 0
|
57
|
+
|
53
58
|
|
54
59
|
kit = EmbedKit.cohere(
|
55
60
|
model=Model.Cohere.EMBED_V4_0,
|
56
61
|
api_key=os.getenv("COHERE_API_KEY"),
|
62
|
+
text_batch_size=64,
|
63
|
+
image_batch_size=8,
|
57
64
|
text_input_type=CohereInputType.SEARCH_QUERY,
|
58
65
|
)
|
59
66
|
|
@@ -64,6 +71,8 @@ assert len(results.shape) == 2
|
|
64
71
|
kit = EmbedKit.cohere(
|
65
72
|
model=Model.Cohere.EMBED_V4_0,
|
66
73
|
api_key=os.getenv("COHERE_API_KEY"),
|
74
|
+
text_batch_size=64,
|
75
|
+
image_batch_size=8,
|
67
76
|
text_input_type=CohereInputType.SEARCH_DOCUMENT,
|
68
77
|
)
|
69
78
|
|
@@ -80,3 +89,8 @@ results = kit.embed_pdf(sample_pdf)
|
|
80
89
|
assert results.shape[0] == 1
|
81
90
|
assert len(results.shape) == 2
|
82
91
|
assert len(results.source_images_b64) > 0
|
92
|
+
|
93
|
+
results = kit.embed_pdf(long_pdf)
|
94
|
+
assert results.shape[0] == 26
|
95
|
+
assert len(results.shape) == 2
|
96
|
+
assert len(results.source_images_b64) > 0
|
@@ -26,21 +26,33 @@ class EmbedKit:
|
|
26
26
|
self._provider = provider_instance
|
27
27
|
|
28
28
|
@classmethod
|
29
|
-
def colpali(
|
29
|
+
def colpali(
|
30
|
+
cls,
|
31
|
+
model: Model = Model.ColPali.V1_3,
|
32
|
+
device: Optional[str] = None,
|
33
|
+
text_batch_size: int = 32,
|
34
|
+
image_batch_size: int = 8,
|
35
|
+
):
|
30
36
|
"""
|
31
37
|
Create EmbedKit instance with ColPali provider.
|
32
38
|
|
33
39
|
Args:
|
34
40
|
model: ColPali model enum
|
35
41
|
device: Device to run on ('cuda', 'mps', 'cpu', or None for auto-detect)
|
42
|
+
text_batch_size: Batch size for text embedding generation
|
43
|
+
image_batch_size: Batch size for image embedding generation
|
36
44
|
"""
|
37
45
|
if model == Model.ColPali.V1_3:
|
38
46
|
model_name = "vidore/colpali-v1.3"
|
39
47
|
else:
|
40
48
|
raise ValueError(f"Unsupported model: {model}")
|
41
49
|
|
42
|
-
|
43
|
-
|
50
|
+
provider = ColPaliProvider(
|
51
|
+
model_name=model_name,
|
52
|
+
device=device,
|
53
|
+
text_batch_size=text_batch_size,
|
54
|
+
image_batch_size=image_batch_size,
|
55
|
+
)
|
44
56
|
return cls(provider)
|
45
57
|
|
46
58
|
@classmethod
|
@@ -48,6 +60,8 @@ class EmbedKit:
|
|
48
60
|
cls,
|
49
61
|
api_key: str,
|
50
62
|
model: Model = Model.Cohere.EMBED_V4_0,
|
63
|
+
text_batch_size: int = 32,
|
64
|
+
image_batch_size: int = 8,
|
51
65
|
text_input_type: CohereInputType = CohereInputType.SEARCH_DOCUMENT,
|
52
66
|
):
|
53
67
|
"""
|
@@ -56,6 +70,8 @@ class EmbedKit:
|
|
56
70
|
Args:
|
57
71
|
api_key: Cohere API key
|
58
72
|
model: Cohere model enum
|
73
|
+
text_batch_size: Batch size for text embedding generation
|
74
|
+
image_batch_size: Batch size for image embedding generation
|
59
75
|
input_type: Type of input for embedding (search_document or search_query)
|
60
76
|
"""
|
61
77
|
if not api_key:
|
@@ -67,7 +83,10 @@ class EmbedKit:
|
|
67
83
|
raise ValueError(f"Unsupported model: {model}")
|
68
84
|
|
69
85
|
provider = CohereProvider(
|
70
|
-
api_key=api_key, model_name=model_name,
|
86
|
+
api_key=api_key, model_name=model_name,
|
87
|
+
text_batch_size=48,
|
88
|
+
image_batch_size=8,
|
89
|
+
text_input_type=text_input_type
|
71
90
|
)
|
72
91
|
return cls(provider)
|
73
92
|
|
@@ -37,9 +37,7 @@ class EmbeddingProvider(ABC):
|
|
37
37
|
pass
|
38
38
|
|
39
39
|
@abstractmethod
|
40
|
-
def embed_pdf(
|
41
|
-
self, pdf: Union[Path, str]
|
42
|
-
) -> EmbeddingResult:
|
40
|
+
def embed_pdf(self, pdf: Union[Path, str]) -> EmbeddingResult:
|
43
41
|
"""Generate image embeddings from PDFsusing the configured provider. Takes a single PDF file."""
|
44
42
|
pass
|
45
43
|
|
@@ -0,0 +1,21 @@
|
|
1
|
+
# ./src/embedkit/classes.py
|
2
|
+
|
3
|
+
"""Core types and enums for the EmbedKit library.
|
4
|
+
|
5
|
+
This module provides the main types and enums that users should interact with:
|
6
|
+
- EmbeddingResult: The result type returned by embedding operations
|
7
|
+
- EmbeddingError: Exception type for embedding operations
|
8
|
+
- Model: Enum of supported embedding models
|
9
|
+
- CohereInputType: Enum for Cohere's input types
|
10
|
+
"""
|
11
|
+
|
12
|
+
from . import EmbeddingResult, EmbeddingError
|
13
|
+
from .models import Model
|
14
|
+
from .providers.cohere import CohereInputType
|
15
|
+
|
16
|
+
__all__ = [
|
17
|
+
"EmbeddingResult",
|
18
|
+
"EmbeddingError",
|
19
|
+
"Model",
|
20
|
+
"CohereInputType"
|
21
|
+
]
|
@@ -24,10 +24,14 @@ class CohereProvider(EmbeddingProvider):
|
|
24
24
|
self,
|
25
25
|
api_key: str,
|
26
26
|
model_name: str,
|
27
|
+
text_batch_size: int,
|
28
|
+
image_batch_size: int,
|
27
29
|
text_input_type: CohereInputType = CohereInputType.SEARCH_DOCUMENT,
|
28
30
|
):
|
29
31
|
self.api_key = api_key
|
30
32
|
self.model_name = model_name
|
33
|
+
self.text_batch_size = text_batch_size
|
34
|
+
self.image_batch_size = image_batch_size
|
31
35
|
self.input_type = text_input_type
|
32
36
|
self._client = None
|
33
37
|
self.provider_name = "Cohere"
|
@@ -55,15 +59,21 @@ class CohereProvider(EmbeddingProvider):
|
|
55
59
|
texts = [texts]
|
56
60
|
|
57
61
|
try:
|
58
|
-
|
59
|
-
|
60
|
-
|
61
|
-
|
62
|
-
|
63
|
-
|
62
|
+
all_embeddings = []
|
63
|
+
|
64
|
+
# Process texts in batches
|
65
|
+
for i in range(0, len(texts), self.text_batch_size):
|
66
|
+
batch_texts = texts[i : i + self.text_batch_size]
|
67
|
+
response = client.embed(
|
68
|
+
texts=batch_texts,
|
69
|
+
model=self.model_name,
|
70
|
+
input_type=self.input_type.value,
|
71
|
+
embedding_types=["float"],
|
72
|
+
)
|
73
|
+
all_embeddings.extend(response.embeddings.float_)
|
64
74
|
|
65
75
|
return EmbeddingResult(
|
66
|
-
embeddings=np.array(
|
76
|
+
embeddings=np.array(all_embeddings),
|
67
77
|
model_name=self.model_name,
|
68
78
|
model_provider=self.provider_name,
|
69
79
|
input_type=self.input_type.value,
|
@@ -81,34 +91,45 @@ class CohereProvider(EmbeddingProvider):
|
|
81
91
|
input_type = "image"
|
82
92
|
|
83
93
|
if isinstance(images, (str, Path)):
|
84
|
-
images = [images]
|
94
|
+
images = [Path(images)]
|
95
|
+
else:
|
96
|
+
images = [Path(img) for img in images]
|
85
97
|
|
86
98
|
try:
|
87
|
-
|
88
|
-
|
89
|
-
b64_image = image_to_base64(image)
|
99
|
+
all_embeddings = []
|
100
|
+
all_b64_images = []
|
90
101
|
|
91
|
-
|
102
|
+
# Process images in batches
|
103
|
+
for i in range(0, len(images), self.image_batch_size):
|
104
|
+
batch_images = images[i : i + self.image_batch_size]
|
105
|
+
b64_images = []
|
92
106
|
|
93
|
-
|
94
|
-
|
95
|
-
|
96
|
-
|
97
|
-
|
98
|
-
|
107
|
+
for image in batch_images:
|
108
|
+
if not image.exists():
|
109
|
+
raise EmbeddingError(f"Image not found: {image}")
|
110
|
+
b64_images.append(image_to_base64(image))
|
111
|
+
|
112
|
+
response = client.embed(
|
113
|
+
model=self.model_name,
|
114
|
+
input_type="image",
|
115
|
+
images=b64_images,
|
116
|
+
embedding_types=["float"],
|
117
|
+
)
|
118
|
+
|
119
|
+
all_embeddings.extend(response.embeddings.float_)
|
120
|
+
all_b64_images.extend(b64_images)
|
99
121
|
|
100
122
|
return EmbeddingResult(
|
101
|
-
embeddings=np.array(
|
123
|
+
embeddings=np.array(all_embeddings),
|
102
124
|
model_name=self.model_name,
|
103
125
|
model_provider=self.provider_name,
|
104
126
|
input_type=input_type,
|
105
|
-
source_images_b64=
|
127
|
+
source_images_b64=all_b64_images,
|
106
128
|
)
|
107
129
|
|
108
130
|
except Exception as e:
|
109
131
|
raise EmbeddingError(f"Failed to embed image with Cohere: {e}") from e
|
110
132
|
|
111
|
-
|
112
133
|
def embed_pdf(self, pdf_path: Path) -> EmbeddingResult:
|
113
134
|
"""Generate embeddings for a PDF file using Cohere API."""
|
114
135
|
image_paths = pdf_to_images(pdf_path)
|
@@ -17,9 +17,17 @@ logger = logging.getLogger(__name__)
|
|
17
17
|
class ColPaliProvider(EmbeddingProvider):
|
18
18
|
"""ColPali embedding provider for document understanding."""
|
19
19
|
|
20
|
-
def __init__(
|
20
|
+
def __init__(
|
21
|
+
self,
|
22
|
+
model_name: str,
|
23
|
+
text_batch_size: int,
|
24
|
+
image_batch_size: int,
|
25
|
+
device: Optional[str] = None,
|
26
|
+
):
|
21
27
|
self.model_name = model_name
|
22
28
|
self.provider_name = "ColPali"
|
29
|
+
self.text_batch_size = text_batch_size
|
30
|
+
self.image_batch_size = image_batch_size
|
23
31
|
|
24
32
|
# Auto-detect device
|
25
33
|
if device is None:
|
@@ -64,13 +72,22 @@ class ColPaliProvider(EmbeddingProvider):
|
|
64
72
|
texts = [texts]
|
65
73
|
|
66
74
|
try:
|
67
|
-
|
75
|
+
# Process texts in batches
|
76
|
+
all_embeddings = []
|
68
77
|
|
69
|
-
|
70
|
-
|
78
|
+
for i in range(0, len(texts), self.text_batch_size):
|
79
|
+
batch_texts = texts[i : i + self.text_batch_size]
|
80
|
+
processed = self._processor.process_queries(batch_texts).to(self.device)
|
81
|
+
|
82
|
+
with torch.no_grad():
|
83
|
+
batch_embeddings = self._model(**processed)
|
84
|
+
all_embeddings.append(batch_embeddings.cpu().float().numpy())
|
85
|
+
|
86
|
+
# Concatenate all batch embeddings
|
87
|
+
final_embeddings = np.concatenate(all_embeddings, axis=0)
|
71
88
|
|
72
89
|
return EmbeddingResult(
|
73
|
-
embeddings=
|
90
|
+
embeddings=final_embeddings,
|
74
91
|
model_name=self.model_name,
|
75
92
|
model_provider=self.provider_name,
|
76
93
|
input_type="text",
|
@@ -91,38 +108,44 @@ class ColPaliProvider(EmbeddingProvider):
|
|
91
108
|
images = [Path(img) for img in images]
|
92
109
|
|
93
110
|
try:
|
94
|
-
|
95
|
-
|
96
|
-
|
97
|
-
if not img_path.exists():
|
98
|
-
raise EmbeddingError(f"Image not found: {img_path}")
|
111
|
+
# Process images in batches
|
112
|
+
all_embeddings = []
|
113
|
+
all_b64_images = []
|
99
114
|
|
100
|
-
|
101
|
-
|
115
|
+
for i in range(0, len(images), self.image_batch_size):
|
116
|
+
batch_images = images[i : i + self.image_batch_size]
|
117
|
+
pil_images = []
|
118
|
+
b64_images = []
|
102
119
|
|
103
|
-
for
|
104
|
-
|
120
|
+
for img_path in batch_images:
|
121
|
+
if not img_path.exists():
|
122
|
+
raise EmbeddingError(f"Image not found: {img_path}")
|
105
123
|
|
106
|
-
|
124
|
+
with Image.open(img_path) as img:
|
125
|
+
pil_images.append(img.convert("RGB"))
|
126
|
+
b64_images.append(image_to_base64(img_path))
|
107
127
|
|
108
|
-
|
128
|
+
processed = self._processor.process_images(pil_images).to(self.device)
|
109
129
|
|
130
|
+
with torch.no_grad():
|
131
|
+
batch_embeddings = self._model(**processed)
|
132
|
+
all_embeddings.append(batch_embeddings.cpu().float().numpy())
|
133
|
+
all_b64_images.extend(b64_images)
|
110
134
|
|
111
|
-
|
112
|
-
|
135
|
+
# Concatenate all batch embeddings
|
136
|
+
final_embeddings = np.concatenate(all_embeddings, axis=0)
|
113
137
|
|
114
138
|
return EmbeddingResult(
|
115
|
-
embeddings=
|
139
|
+
embeddings=final_embeddings,
|
116
140
|
model_name=self.model_name,
|
117
141
|
model_provider=self.provider_name,
|
118
142
|
input_type="image",
|
119
|
-
source_images_b64=
|
143
|
+
source_images_b64=all_b64_images,
|
120
144
|
)
|
121
145
|
|
122
146
|
except Exception as e:
|
123
147
|
raise EmbeddingError(f"Failed to embed images: {e}") from e
|
124
148
|
|
125
|
-
|
126
149
|
def embed_pdf(self, pdf_path: Path) -> EmbeddingResult:
|
127
150
|
"""Generate embeddings for a PDF file using ColPali API."""
|
128
151
|
images = pdf_to_images(pdf_path)
|
@@ -26,13 +26,9 @@ def image_to_base64(image_path: Union[str, Path]):
|
|
26
26
|
import base64
|
27
27
|
|
28
28
|
try:
|
29
|
-
base64_only = base64.b64encode(Path(image_path).read_bytes()).decode(
|
30
|
-
"utf-8"
|
31
|
-
)
|
29
|
+
base64_only = base64.b64encode(Path(image_path).read_bytes()).decode("utf-8")
|
32
30
|
except Exception as e:
|
33
|
-
raise ValueError(
|
34
|
-
f"Failed to read image {image_path}: {e}"
|
35
|
-
) from e
|
31
|
+
raise ValueError(f"Failed to read image {image_path}: {e}") from e
|
36
32
|
|
37
33
|
if isinstance(image_path, Path):
|
38
34
|
image_path_str = str(image_path)
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|