embedkit 0.1.4__tar.gz → 0.1.6__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {embedkit-0.1.4 → embedkit-0.1.6}/PKG-INFO +18 -12
- {embedkit-0.1.4 → embedkit-0.1.6}/README.md +16 -10
- {embedkit-0.1.4 → embedkit-0.1.6}/main.py +36 -28
- {embedkit-0.1.4 → embedkit-0.1.6}/pyproject.toml +2 -2
- {embedkit-0.1.4 → embedkit-0.1.6}/src/embedkit/__init__.py +10 -12
- embedkit-0.1.6/src/embedkit/base.py +122 -0
- {embedkit-0.1.4 → embedkit-0.1.6}/src/embedkit/classes.py +1 -6
- embedkit-0.1.6/src/embedkit/models.py +18 -0
- {embedkit-0.1.4 → embedkit-0.1.6}/src/embedkit/providers/cohere.py +31 -44
- embedkit-0.1.6/src/embedkit/providers/colpali.py +162 -0
- embedkit-0.1.6/src/embedkit/utils.py +142 -0
- embedkit-0.1.6/tests/fixtures/2407.01449v6_p1_p5.pdf +0 -0
- {embedkit-0.1.4 → embedkit-0.1.6}/tests/test_embedkit.py +2 -2
- embedkit-0.1.6/tests/test_utils.py +52 -0
- {embedkit-0.1.4 → embedkit-0.1.6}/uv.lock +2 -2
- embedkit-0.1.4/src/embedkit/base.py +0 -53
- embedkit-0.1.4/src/embedkit/models.py +0 -12
- embedkit-0.1.4/src/embedkit/providers/colpali.py +0 -160
- embedkit-0.1.4/src/embedkit/utils.py +0 -48
- {embedkit-0.1.4 → embedkit-0.1.6}/.gitignore +0 -0
- {embedkit-0.1.4 → embedkit-0.1.6}/.python-version +0 -0
- {embedkit-0.1.4 → embedkit-0.1.6}/LICENSE +0 -0
- {embedkit-0.1.4 → embedkit-0.1.6}/src/embedkit/config.py +0 -0
- {embedkit-0.1.4 → embedkit-0.1.6}/src/embedkit/providers/__init__.py +0 -0
- {embedkit-0.1.4 → embedkit-0.1.6}/tests/conftest.py +0 -0
- {embedkit-0.1.4 → embedkit-0.1.6}/tests/fixtures/2407.01449v6_p1.pdf +0 -0
- {embedkit-0.1.4 → embedkit-0.1.6}/tests/fixtures/2407.01449v6_p1.png +0 -0
{embedkit-0.1.4 → embedkit-0.1.6}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: embedkit
-Version: 0.1.4
+Version: 0.1.6
 Summary: A simple toolkit for generating vector embeddings across multiple providers and models
 Author-email: JP Hwang <me@jphwang.com>
 License: MIT
@@ -22,7 +22,7 @@ Requires-Dist: colpali-engine<0.4.0,>=0.3.0
 Requires-Dist: pdf2image>=1.17.0
 Requires-Dist: pillow>=11.2.1
 Requires-Dist: torch<=2.5
-Requires-Dist: transformers
+Requires-Dist: transformers>=4.46.2
 Description-Content-Type: text/markdown
 
 # EmbedKit
@@ -45,7 +45,7 @@ from embedkit.classes import Model, CohereInputType
 
 # Initialize with ColPali
 kit = EmbedKit.colpali(
-    model=Model.ColPali.V1_3,
+    model=Model.ColPali.COLPALI_V1_3,  # or COLSMOL_256M, COLSMOL_500M
     text_batch_size=16,  # Optional: process text in batches of 16
     image_batch_size=8,  # Optional: process images in batches of 8
 )
@@ -54,7 +54,7 @@ kit = EmbedKit.colpali(
 result = kit.embed_text("Hello world")
 print(result.model_provider)
 print(result.input_type)
-print(result.objects[0].embedding.shape)
+print(result.objects[0].embedding.shape)  # Returns 2D array for ColPali
 print(result.objects[0].source_b64)
 
 # Initialize with Cohere
@@ -70,7 +70,7 @@ kit = EmbedKit.cohere(
 result = kit.embed_text("Hello world")
 print(result.model_provider)
 print(result.input_type)
-print(result.objects[0].embedding.shape)
+print(result.objects[0].embedding.shape)  # Returns 1D array for Cohere
 print(result.objects[0].source_b64)
 ```
 
@@ -85,8 +85,8 @@ result = kit.embed_image(image_path)
 
 print(result.model_provider)
 print(result.input_type)
-print(result.objects[0].embedding.shape)
-print(result.objects[0].source_b64)
+print(result.objects[0].embedding.shape)  # 2D for ColPali, 1D for Cohere
+print(result.objects[0].source_b64)  # Base64 encoded image
 ```
 
 ### PDF Embeddings
@@ -100,8 +100,8 @@ result = kit.embed_pdf(pdf_path)
 
 print(result.model_provider)
 print(result.input_type)
-print(result.objects[0].embedding.shape)
-print(result.objects[0].source_b64)
+print(result.objects[0].embedding.shape)  # 2D for ColPali, 1D for Cohere
+print(result.objects[0].source_b64)  # Base64 encoded PDF page
 ```
 
 ## Response Format
@@ -116,17 +116,23 @@ class EmbeddingResponse:
     objects: List[EmbeddingObject]
 
 class EmbeddingObject:
-    embedding: np.ndarray
-    source_b64: Optional[str]
+    embedding: np.ndarray  # 1D array for Cohere, 2D array for ColPali
+    source_b64: Optional[str]  # Base64 encoded source for images and PDFs
 ```
 
 ## Supported Models
 
 ### ColPali
-- `Model.ColPali.V1_3`
+- `Model.ColPali.COLPALI_V1_3`
+- `Model.ColPali.COLSMOL_256M`
+- `Model.ColPali.COLSMOL_500M`
 
 ### Cohere
 - `Model.Cohere.EMBED_V4_0`
+- `Model.Cohere.EMBED_ENGLISH_V3_0`
+- `Model.Cohere.EMBED_ENGLISH_LIGHT_V3_0`
+- `Model.Cohere.EMBED_MULTILINGUAL_V3_0`
+- `Model.Cohere.EMBED_MULTILINGUAL_LIGHT_V3_0`
 
 ## Requirements
{embedkit-0.1.4 → embedkit-0.1.6}/README.md

@@ -18,7 +18,7 @@ from embedkit.classes import Model, CohereInputType
 
 # Initialize with ColPali
 kit = EmbedKit.colpali(
-    model=Model.ColPali.V1_3,
+    model=Model.ColPali.COLPALI_V1_3,  # or COLSMOL_256M, COLSMOL_500M
     text_batch_size=16,  # Optional: process text in batches of 16
     image_batch_size=8,  # Optional: process images in batches of 8
 )
@@ -27,7 +27,7 @@ kit = EmbedKit.colpali(
 result = kit.embed_text("Hello world")
 print(result.model_provider)
 print(result.input_type)
-print(result.objects[0].embedding.shape)
+print(result.objects[0].embedding.shape)  # Returns 2D array for ColPali
 print(result.objects[0].source_b64)
 
 # Initialize with Cohere
@@ -43,7 +43,7 @@ kit = EmbedKit.cohere(
 result = kit.embed_text("Hello world")
 print(result.model_provider)
 print(result.input_type)
-print(result.objects[0].embedding.shape)
+print(result.objects[0].embedding.shape)  # Returns 1D array for Cohere
 print(result.objects[0].source_b64)
 ```
 
@@ -58,8 +58,8 @@ result = kit.embed_image(image_path)
 
 print(result.model_provider)
 print(result.input_type)
-print(result.objects[0].embedding.shape)
-print(result.objects[0].source_b64)
+print(result.objects[0].embedding.shape)  # 2D for ColPali, 1D for Cohere
+print(result.objects[0].source_b64)  # Base64 encoded image
 ```
 
 ### PDF Embeddings
@@ -73,8 +73,8 @@ result = kit.embed_pdf(pdf_path)
 
 print(result.model_provider)
 print(result.input_type)
-print(result.objects[0].embedding.shape)
-print(result.objects[0].source_b64)
+print(result.objects[0].embedding.shape)  # 2D for ColPali, 1D for Cohere
+print(result.objects[0].source_b64)  # Base64 encoded PDF page
 ```
 
 ## Response Format
@@ -89,17 +89,23 @@ class EmbeddingResponse:
     objects: List[EmbeddingObject]
 
 class EmbeddingObject:
-    embedding: np.ndarray
-    source_b64: Optional[str]
+    embedding: np.ndarray  # 1D array for Cohere, 2D array for ColPali
+    source_b64: Optional[str]  # Base64 encoded source for images and PDFs
 ```
 
 ## Supported Models
 
 ### ColPali
-- `Model.ColPali.V1_3`
+- `Model.ColPali.COLPALI_V1_3`
+- `Model.ColPali.COLSMOL_256M`
+- `Model.ColPali.COLSMOL_500M`
 
 ### Cohere
 - `Model.Cohere.EMBED_V4_0`
+- `Model.Cohere.EMBED_ENGLISH_V3_0`
+- `Model.Cohere.EMBED_ENGLISH_LIGHT_V3_0`
+- `Model.Cohere.EMBED_MULTILINGUAL_V3_0`
+- `Model.Cohere.EMBED_MULTILINGUAL_LIGHT_V3_0`
 
 ## Requirements
{embedkit-0.1.4 → embedkit-0.1.6}/main.py

@@ -32,30 +32,7 @@ def get_sample_image() -> Path:
 sample_image = get_sample_image()
 
 sample_pdf = Path("tests/fixtures/2407.01449v6_p1.pdf")
-
-
-kit = EmbedKit.colpali(model=Model.ColPali.V1_3, text_batch_size=16, image_batch_size=8)
-
-results = kit.embed_text("Hello world")
-assert len(results.objects) == 1
-assert len(results.objects[0].embedding.shape) == 2
-assert results.objects[0].source_b64 == None
-
-results = kit.embed_image(sample_image)
-assert len(results.objects) == 1
-assert len(results.objects[0].embedding.shape) == 2
-assert type(results.objects[0].source_b64) == str
-
-results = kit.embed_pdf(sample_pdf)
-assert len(results.objects) == 1
-assert len(results.objects[0].embedding.shape) == 2
-assert type(results.objects[0].source_b64) == str
-
-# results = kit.embed_pdf(long_pdf)
-# assert len(results.objects) == 26
-# assert len(results.objects[0].embedding.shape) == 2
-# assert type(results.objects[0].source_b64) == str
-
+longer_pdf = Path("tests/fixtures/2407.01449v6_p1_p5.pdf")
 
 kit = EmbedKit.cohere(
     model=Model.Cohere.EMBED_V4_0,
@@ -65,6 +42,7 @@ kit = EmbedKit.cohere(
     text_input_type=CohereInputType.SEARCH_QUERY,
 )
 
+print(f"Trying out Cohere")
 results = kit.embed_text("Hello world")
 assert len(results.objects) == 1
 assert len(results.objects[0].embedding.shape) == 1
@@ -93,7 +71,37 @@ assert len(results.objects) == 1
 assert len(results.objects[0].embedding.shape) == 1
 assert type(results.objects[0].source_b64) == str
 
-
-
-
-
+results = kit.embed_pdf(longer_pdf)
+assert len(results.objects) == 5
+assert len(results.objects[0].embedding.shape) == 1
+assert type(results.objects[0].source_b64) == str
+
+for colpali_model in [
+    Model.ColPali.COLSMOL_256M,
+    Model.ColPali.COLSMOL_500M,
+    Model.ColPali.COLPALI_V1_3,
+]:
+    print(f"Trying out {colpali_model}")
+    kit = EmbedKit.colpali(
+        model=colpali_model, text_batch_size=16, image_batch_size=8
+    )
+
+    results = kit.embed_text("Hello world")
+    assert len(results.objects) == 1
+    assert len(results.objects[0].embedding.shape) == 2
+    assert results.objects[0].source_b64 == None
+
+    results = kit.embed_image(sample_image)
+    assert len(results.objects) == 1
+    assert len(results.objects[0].embedding.shape) == 2
+    assert type(results.objects[0].source_b64) == str
+
+    results = kit.embed_pdf(sample_pdf)
+    assert len(results.objects) == 1
+    assert len(results.objects[0].embedding.shape) == 2
+    assert type(results.objects[0].source_b64) == str
+
+    results = kit.embed_pdf(longer_pdf)
+    assert len(results.objects) == 5
+    assert len(results.objects[0].embedding.shape) == 2
+    assert type(results.objects[0].source_b64) == str
{embedkit-0.1.4 → embedkit-0.1.6}/pyproject.toml

@@ -1,6 +1,6 @@
 [project]
 name = "embedkit"
-version = "0.1.4"
+version = "0.1.6"
 description = "A simple toolkit for generating vector embeddings across multiple providers and models"
 readme = "README.md"
 requires-python = ">=3.10"
@@ -11,7 +11,7 @@ dependencies = [
     "pdf2image>=1.17.0",
     "pillow>=11.2.1",
     "torch<=2.5",
-    "transformers",
+    "transformers>=4.46.2",
 ]
 authors = [
     {name = "JP Hwang", email = "me@jphwang.com"},
{embedkit-0.1.4 → embedkit-0.1.6}/src/embedkit/__init__.py

@@ -5,7 +5,6 @@ EmbedKit: A unified toolkit for generating vector embeddings.
 
 from typing import Union, List, Optional
 from pathlib import Path
-import numpy as np
 
 from .models import Model
 from .base import EmbeddingError, EmbeddingResponse
@@ -28,7 +27,7 @@ class EmbedKit:
     @classmethod
     def colpali(
         cls,
-        model: Model = Model.ColPali.V1_3,
+        model: Model = Model.ColPali.COLPALI_V1_3,
         device: Optional[str] = None,
         text_batch_size: int = 32,
         image_batch_size: int = 8,
@@ -42,13 +41,13 @@ class EmbedKit:
             text_batch_size: Batch size for text embedding generation
             image_batch_size: Batch size for image embedding generation
         """
-        if model
-
-
-
+        if not isinstance(model, Model.ColPali):
+            raise ValueError(
+                f"Unsupported model: {model}. Must be a Model.ColPali enum value."
+            )
 
         provider = ColPaliProvider(
-
+            model=model,
             device=device,
             text_batch_size=text_batch_size,
             image_batch_size=image_batch_size,
@@ -77,16 +76,15 @@ class EmbedKit:
         if not api_key:
             raise ValueError("API key is required")
 
-        if model == Model.Cohere.EMBED_V4_0:
-            model_name = "embed-v4.0"
-        else:
+        if not isinstance(model, Model.Cohere):
             raise ValueError(f"Unsupported model: {model}")
 
         provider = CohereProvider(
-            api_key=api_key,
+            api_key=api_key,
+            model=model,
             text_batch_size=text_batch_size,
             image_batch_size=image_batch_size,
-            text_input_type=text_input_type
+            text_input_type=text_input_type,
         )
         return cls(provider)
 
embedkit-0.1.6/src/embedkit/base.py (new)

@@ -0,0 +1,122 @@
+# ./src/embedkit/base.py
+"""Base classes for EmbedKit."""
+
+from abc import ABC, abstractmethod
+from typing import Union, List, Optional
+from pathlib import Path
+import numpy as np
+from dataclasses import dataclass
+
+from .models import Model
+from .utils import with_pdf_cleanup
+
+
+@dataclass
+class EmbeddingObject:
+    embedding: np.ndarray
+    source_b64: str = None
+    source_content_type: str = None  # e.g., "image/png", "image/jpeg"
+
+
+@dataclass
+class EmbeddingResponse:
+    model_name: str
+    model_provider: str
+    input_type: str
+    objects: List[EmbeddingObject]
+
+    @property
+    def shape(self) -> tuple:
+        return self.objects[0].embedding.shape
+
+
+class EmbeddingProvider(ABC):
+    """Abstract base class for embedding providers."""
+
+    def __init__(
+        self,
+        model_name: str,
+        text_batch_size: int,
+        image_batch_size: int,
+        provider_name: str,
+    ):
+        self.model_name = model_name
+        self.provider_name = provider_name
+        self.text_batch_size = text_batch_size
+        self.image_batch_size = image_batch_size
+
+    def _normalize_text_input(self, texts: Union[str, List[str]]) -> List[str]:
+        """Normalize text input to a list of strings."""
+        if isinstance(texts, str):
+            return [texts]
+        return texts
+
+    def _normalize_image_input(
+        self, images: Union[Path, str, List[Union[Path, str]]]
+    ) -> List[Path]:
+        """Normalize image input to a list of Path objects."""
+        if isinstance(images, (str, Path)):
+            return [Path(images)]
+        return [Path(img) for img in images]
+
+    def _create_text_response(
+        self, embeddings: List[np.ndarray], input_type: str = "text"
+    ) -> EmbeddingResponse:
+        """Create a standardized text embedding response."""
+        return EmbeddingResponse(
+            model_name=self.model_name,
+            model_provider=self.provider_name,
+            input_type=input_type,
+            objects=[EmbeddingObject(embedding=e) for e in embeddings],
+        )
+
+    def _create_image_response(
+        self,
+        embeddings: List[np.ndarray],
+        b64_data: List[str],
+        content_types: List[str],
+        input_type: str = "image",
+    ) -> EmbeddingResponse:
+        """Create a standardized image embedding response."""
+        return EmbeddingResponse(
+            model_name=self.model_name,
+            model_provider=self.provider_name,
+            input_type=input_type,
+            objects=[
+                EmbeddingObject(
+                    embedding=embedding,
+                    source_b64=b64_data,
+                    source_content_type=content_type,
+                )
+                for embedding, b64_data, content_type in zip(
+                    embeddings, b64_data, content_types
+                )
+            ],
+        )
+
+    @abstractmethod
+    def embed_text(self, texts: Union[str, List[str]], **kwargs) -> EmbeddingResponse:
+        """Generate document text embeddings using the configured provider."""
+        pass
+
+    @abstractmethod
+    def embed_image(
+        self, images: Union[Path, str, List[Union[Path, str]]]
+    ) -> EmbeddingResponse:
+        """Generate image embeddings using the configured provider."""
+        pass
+
+    def embed_pdf(self, pdf_path: Path) -> EmbeddingResponse:
+        """Generate embeddings for a PDF file."""
+        return self._embed_pdf_impl(pdf_path)
+
+    @with_pdf_cleanup
+    def _embed_pdf_impl(self, pdf_path: List[Path]) -> EmbeddingResponse:
+        """Internal implementation of PDF embedding with cleanup handled by decorator."""
+        return self.embed_image(pdf_path)
+
+
+class EmbeddingError(Exception):
+    """Base exception for embedding-related errors."""
+
+    pass
{embedkit-0.1.4 → embedkit-0.1.6}/src/embedkit/classes.py

@@ -13,9 +13,4 @@ from . import EmbeddingResponse, EmbeddingError
 from .models import Model
 from .providers.cohere import CohereInputType
 
-__all__ = [
-    "EmbeddingResponse",
-    "EmbeddingError",
-    "Model",
-    "CohereInputType"
-]
+__all__ = ["EmbeddingResponse", "EmbeddingError", "Model", "CohereInputType"]
embedkit-0.1.6/src/embedkit/models.py (new)

@@ -0,0 +1,18 @@
+# ./src/embedkit/models.py
+"""Model definitions and enum for EmbedKit."""
+
+from enum import Enum
+
+
+class Model:
+    class ColPali(Enum):
+        COLPALI_V1_3 = "vidore/colpali-v1.3"
+        COLSMOL_500M = "vidore/colSmol-500M"
+        COLSMOL_256M = "vidore/colSmol-256M"
+
+    class Cohere(Enum):
+        EMBED_V4_0 = "embed-v4.0"
+        EMBED_ENGLISH_V3_0 = "embed-english-v3.0"
+        EMBED_ENGLISH_LIGHT_V3_0 = "embed-english-light-v3.0"
+        EMBED_MULTILINGUAL_V3_0 = "embed-multilingual-v3.0"
+        EMBED_MULTILINGUAL_LIGHT_V3_0 = "embed-multilingual-light-v3.0"
{embedkit-0.1.4 → embedkit-0.1.6}/src/embedkit/providers/cohere.py

@@ -5,9 +5,13 @@ from typing import Union, List
 from pathlib import Path
 import numpy as np
 from enum import Enum
+import logging
 
-from ..utils import pdf_to_images, image_to_base64
-from ..base import EmbeddingProvider, EmbeddingError, EmbeddingResponse, EmbeddingObject
+from ..models import Model
+from ..utils import image_to_base64
+from ..base import EmbeddingProvider, EmbeddingError, EmbeddingResponse
+
+logger = logging.getLogger(__name__)
 
 
 class CohereInputType(Enum):
@@ -23,18 +27,20 @@ class CohereProvider(EmbeddingProvider):
     def __init__(
         self,
         api_key: str,
-        model_name: str,
+        model: Model.Cohere,
         text_batch_size: int,
        image_batch_size: int,
         text_input_type: CohereInputType = CohereInputType.SEARCH_DOCUMENT,
     ):
+        super().__init__(
+            model_name=model.value,
+            text_batch_size=text_batch_size,
+            image_batch_size=image_batch_size,
+            provider_name="Cohere",
+        )
         self.api_key = api_key
-        self.model_name = model_name
-        self.text_batch_size = text_batch_size
-        self.image_batch_size = image_batch_size
         self.input_type = text_input_type
         self._client = None
-        self.provider_name = "Cohere"
 
     def _get_client(self):
         """Lazy load the Cohere client."""
@@ -54,9 +60,7 @@ class CohereProvider(EmbeddingProvider):
     def embed_text(self, texts: Union[str, List[str]], **kwargs) -> EmbeddingResponse:
         """Generate text embeddings using the Cohere API."""
         client = self._get_client()
-
-        if isinstance(texts, str):
-            texts = [texts]
+        texts = self._normalize_text_input(texts)
 
         try:
             all_embeddings = []
@@ -72,16 +76,7 @@ class CohereProvider(EmbeddingProvider):
                 )
                 all_embeddings.extend(np.array(response.embeddings.float_))
 
-            return EmbeddingResponse(
-                model_name=self.model_name,
-                model_provider=self.provider_name,
-                input_type=self.input_type.value,
-                objects=[
-                    EmbeddingObject(
-                        embedding=e,
-                    ) for e in all_embeddings
-                ]
-            )
+            return self._create_text_response(all_embeddings, self.input_type.value)
 
         except Exception as e:
             raise EmbeddingError(f"Failed to embed text with Cohere: {e}") from e
@@ -92,12 +87,9 @@ class CohereProvider(EmbeddingProvider):
     ) -> EmbeddingResponse:
         """Generate embeddings for images using Cohere API."""
         client = self._get_client()
-
-
-        if isinstance(images, (str, Path)):
-            images = [Path(images)]
-        else:
-            images = [Path(img) for img in images]
+        images = self._normalize_image_input(images)
+        total_images = len(images)
+        logger.info(f"Starting to process {total_images} images")
 
         try:
             all_embeddings = []
@@ -106,12 +98,17 @@ class CohereProvider(EmbeddingProvider):
             # Process images in batches
             for i in range(0, len(images), self.image_batch_size):
                 batch_images = images[i : i + self.image_batch_size]
+                logger.info(f"Processing batch {i//self.image_batch_size + 1} of {(total_images + self.image_batch_size - 1)//self.image_batch_size} ({len(batch_images)} images)")
                 b64_images = []
 
                 for image in batch_images:
                     if not image.exists():
                         raise EmbeddingError(f"Image not found: {image}")
-                    b64_images.append(image_to_base64(image))
+                    b64_data, content_type = image_to_base64(image)
+                    # Construct full data URI for API
+                    data_uri = f"data:{content_type};base64,{b64_data}"
+                    b64_images.append(data_uri)
+                    all_b64_images.append((b64_data, content_type))
 
                 response = client.embed(
                     model=self.model_name,
@@ -121,24 +118,14 @@ class CohereProvider(EmbeddingProvider):
                 )
 
                 all_embeddings.extend(np.array(response.embeddings.float_))
-
-
-            return EmbeddingResponse(
-                model_name=self.model_name,
-                model_provider=self.provider_name,
-                input_type="image",
-                objects=[
-                    EmbeddingObject(
-                        embedding=all_embeddings[i],
-                        source_b64=all_b64_images[i]
-                    ) for i in range(len(all_embeddings))
-                ]
+
+            logger.info(f"Successfully processed all {total_images} images")
+            return self._create_image_response(
+                all_embeddings,
+                [b64 for b64, _ in all_b64_images],
+                [content_type for _, content_type in all_b64_images],
             )
 
         except Exception as e:
+            logger.error(f"Failed to embed images: {e}")
             raise EmbeddingError(f"Failed to embed image with Cohere: {e}") from e
-
-    def embed_pdf(self, pdf_path: Path) -> EmbeddingResponse:
-        """Generate embeddings for a PDF file using Cohere API."""
-        image_paths = pdf_to_images(pdf_path)
-        return self.embed_image(image_paths)
embedkit-0.1.6/src/embedkit/providers/colpali.py (new)

@@ -0,0 +1,162 @@
+# ./src/embedkit/providers/colpali.py
+"""ColPali embedding provider."""
+
+from typing import Union, List, Optional
+from pathlib import Path
+import logging
+import numpy as np
+import torch
+from PIL import Image
+
+from ..models import Model
+from ..utils import image_to_base64
+from ..base import EmbeddingProvider, EmbeddingError, EmbeddingResponse
+
+logger = logging.getLogger(__name__)
+
+
+class ColPaliProvider(EmbeddingProvider):
+    """ColPali embedding provider for document understanding."""
+
+    def __init__(
+        self,
+        model: Model.ColPali,
+        text_batch_size: int,
+        image_batch_size: int,
+        device: Optional[str] = None,
+    ):
+        super().__init__(
+            model_name=model.value,
+            text_batch_size=text_batch_size,
+            image_batch_size=image_batch_size,
+            provider_name="ColPali",
+        )
+
+        # Auto-detect device
+        if device is None:
+            if torch.cuda.is_available():
+                device = "cuda"
+            elif torch.backends.mps.is_available():
+                device = "mps"
+            else:
+                device = "cpu"
+
+        self._hf_device = device
+        self._hf_model = None
+        self._hf_processor = None
+
+    def _load_model(self):
+        """Lazy load the model."""
+        if self._hf_model is None:
+            try:
+                if self.model_name in [Model.ColPali.COLPALI_V1_3.value]:
+                    from colpali_engine.models import ColPali, ColPaliProcessor
+
+                    self._hf_model = ColPali.from_pretrained(
+                        self.model_name,
+                        torch_dtype=torch.bfloat16,
+                        device_map=self._hf_device,
+                    ).eval()
+
+                    self._hf_processor = ColPaliProcessor.from_pretrained(self.model_name)
+
+                elif self.model_name in [
+                    Model.ColPali.COLSMOL_500M.value,
+                    Model.ColPali.COLSMOL_256M.value,
+                ]:
+                    from colpali_engine.models import ColIdefics3, ColIdefics3Processor
+
+                    self._hf_model = ColIdefics3.from_pretrained(
+                        self.model_name,
+                        torch_dtype=torch.bfloat16,
+                        device_map=self._hf_device,
+                    ).eval()
+                    self._hf_processor = ColIdefics3Processor.from_pretrained(
+                        self.model_name
+                    )
+                else:
+                    raise ValueError(f"Unable to load model for: {self.model_name}.")
+
+                logger.info(f"Loaded {self.model_name} on {self._hf_device}")
+
+            except ImportError as e:
+                raise EmbeddingError(
+                    "ColPali not installed. Run: pip install colpali-engine"
+                ) from e
+            except Exception as e:
+                raise EmbeddingError(f"Failed to load model: {e}") from e
+
+    def embed_text(self, texts: Union[str, List[str]], **kwargs) -> EmbeddingResponse:
+        """Generate embeddings for text inputs."""
+        self._load_model()
+        texts = self._normalize_text_input(texts)
+
+        try:
+            # Process texts in batches
+            all_embeddings: List[np.ndarray] = []
+
+            for i in range(0, len(texts), self.text_batch_size):
+                batch_texts = texts[i : i + self.text_batch_size]
+                processed = self._hf_processor.process_queries(batch_texts).to(self._hf_device)
+
+                with torch.no_grad():
+                    batch_embeddings = self._hf_model(**processed)
+                    all_embeddings.append(batch_embeddings.cpu().float().numpy())
+
+            # Concatenate all batch embeddings
+            final_embeddings = np.concatenate(all_embeddings, axis=0)
+            return self._create_text_response(final_embeddings)
+
+        except Exception as e:
+            raise EmbeddingError(f"Failed to embed text: {e}") from e
+
+    def embed_image(
+        self, images: Union[Path, str, List[Union[Path, str]]]
+    ) -> EmbeddingResponse:
+        """Generate embeddings for images."""
+        self._load_model()
+        images = self._normalize_image_input(images)
+        total_images = len(images)
+        logger.info(f"Starting to process {total_images} images")
+
+        try:
+            # Process images in batches
+            all_embeddings: List[np.ndarray] = []
+            all_b64_data: List[str] = []
+            all_content_types: List[str] = []
+
+            for i in range(0, len(images), self.image_batch_size):
+                batch_images = images[i : i + self.image_batch_size]
+                logger.info(f"Processing batch {i//self.image_batch_size + 1} of {(total_images + self.image_batch_size - 1)//self.image_batch_size} ({len(batch_images)} images)")
+                pil_images = []
+                batch_b64_data = []
+                batch_content_types = []
+
+                for img_path in batch_images:
+                    if not img_path.exists():
+                        raise EmbeddingError(f"Image not found: {img_path}")
+
+                    with Image.open(img_path) as img:
+                        pil_images.append(img.convert("RGB"))
+                    b64, content_type = image_to_base64(img_path)
+                    batch_b64_data.append(b64)
+                    batch_content_types.append(content_type)
+
+                processed = self._hf_processor.process_images(pil_images).to(self._hf_device)
+
+                with torch.no_grad():
+                    batch_embeddings = self._hf_model(**processed)
+                    all_embeddings.append(batch_embeddings.cpu().float().numpy())
+                all_b64_data.extend(batch_b64_data)
+                all_content_types.extend(batch_content_types)
+
+            # Concatenate all batch embeddings
+            final_embeddings = np.concatenate(all_embeddings, axis=0)
+            logger.info(f"Successfully processed all {total_images} images")
+            return self._create_image_response(
+                final_embeddings, all_b64_data, all_content_types
+            )
+
+        except Exception as e:
+            logger.error(f"Failed to embed images: {e}")
+            raise EmbeddingError(f"Failed to embed images: {e}") from e
embedkit-0.1.6/src/embedkit/utils.py (new)

@@ -0,0 +1,142 @@
+import tempfile
+import shutil
+import logging
+from contextlib import contextmanager
+from pdf2image import convert_from_path
+from pathlib import Path
+from .config import get_temp_dir
+from typing import Union, List, Iterator, Callable, TypeVar, Any
+
+logger = logging.getLogger(__name__)
+
+T = TypeVar("T")
+
+
+@contextmanager
+def temporary_directory() -> Iterator[Path]:
+    """Create a temporary directory that is automatically cleaned up when done.
+
+    Yields:
+        Path: Path to the temporary directory
+    """
+    temp_dir = Path(tempfile.mkdtemp())
+    try:
+        yield temp_dir
+    finally:
+        shutil.rmtree(temp_dir)
+
+
+def pdf_to_images(pdf_path: Path) -> List[Path]:
+    """Convert a PDF file to a list of images.
+
+    The images are stored in a temporary directory that will be automatically
+    cleaned up when the process exits.
+
+    Args:
+        pdf_path: Path to the PDF file
+
+    Returns:
+        List[Path]: List of paths to the generated images
+
+    Note:
+        The temporary files will be automatically cleaned up when the process exits.
+        Do not rely on these files persisting after the function returns.
+    """
+    with temporary_directory() as temp_dir:
+        images = convert_from_path(pdf_path=str(pdf_path), output_folder=str(temp_dir))
+        image_paths = []
+
+        for i, image in enumerate(images):
+            output_path = temp_dir / f"{pdf_path.stem}_{i}.png"
+            image.save(output_path)
+            final_path = Path(tempfile.mktemp(suffix=".png"))
+            shutil.move(output_path, final_path)
+            image_paths.append(final_path)
+
+        return image_paths
+
+
+def image_to_base64(image_path: Union[str, Path]) -> tuple[str, str]:
+    """Convert an image to base64 and return the base64 data and content type.
+
+    Args:
+        image_path: Path to the image file
+
+    Returns:
+        tuple[str, str]: (base64_data, content_type)
+
+    Raises:
+        ValueError: If the image cannot be read or has an unsupported format
+    """
+    import base64
+
+    try:
+        base64_data = base64.b64encode(Path(image_path).read_bytes()).decode("utf-8")
+    except Exception as e:
+        raise ValueError(f"Failed to read image {image_path}: {e}") from e
+
+    if isinstance(image_path, Path):
+        image_path_str = str(image_path)
+    else:
+        image_path_str = image_path
+
+    if image_path_str.lower().endswith(".png"):
+        content_type = "image/png"
+    elif image_path_str.lower().endswith((".jpg", ".jpeg")):
+        content_type = "image/jpeg"
+    elif image_path_str.lower().endswith(".gif"):
+        content_type = "image/gif"
+    else:
+        raise ValueError(
+            f"Unsupported image format for {image_path}; expected .png, .jpg, .jpeg, or .gif"
+        )
+
+    return base64_data, content_type
+
+
+def with_pdf_cleanup(embed_func: Callable[..., T]) -> Callable[..., T]:
+    """Decorator to handle PDF to image conversion with automatic cleanup.
+
+    This decorator handles the common pattern of:
+    1. Converting PDF to images
+    2. Passing images to an embedding function
+    3. Cleaning up temporary files
+
+    Args:
+        embed_func: Function that takes a list of image paths and returns embeddings
+
+    Returns:
+        Callable that takes a PDF path and returns embeddings
+    """
+
+    def wrapper(*args, **kwargs) -> T:
+        # First argument is self for instance methods
+        pdf_path = args[-1] if args else kwargs.get("pdf_path")
+        if not pdf_path:
+            raise ValueError(
+                "PDF path must be provided as the last positional argument or as 'pdf_path' keyword argument"
+            )
+
+        images = []  # Initialize images as empty list
+        try:
+            images = pdf_to_images(pdf_path)
+            # Call the original function with the images instead of pdf_path
+            if args:
+                # For instance methods, replace the last argument (pdf_path) with images
+                args = list(args)
+                args[-1] = images
+            else:
+                kwargs["pdf_path"] = images
+            return embed_func(*args, **kwargs)
+        finally:
+            # Clean up temporary files created by pdf_to_images
+            for img_path in images:
+                try:
+                    if img_path.exists() and str(img_path).startswith(
+                        tempfile.gettempdir()
+                    ):
+                        img_path.unlink()
+                except Exception as e:
+                    logger.warning(f"Failed to clean up temporary file {img_path}: {e}")
+
+    return wrapper
embedkit-0.1.6/tests/fixtures/2407.01449v6_p1_p5.pdf (new)

Binary file
{embedkit-0.1.4 → embedkit-0.1.6}/tests/test_embedkit.py

@@ -114,7 +114,7 @@ def test_cohere_missing_api_key():
 # ===============================
 def test_colpali_text_embedding():
     """Test text embedding with Colpali model."""
-    kit = EmbedKit.colpali(model=Model.ColPali.V1_3)
+    kit = EmbedKit.colpali(model=Model.ColPali.COLPALI_V1_3)
     result = kit.embed_text("Hello world")
 
     assert len(result.objects) == 1
@@ -133,7 +133,7 @@ def test_colpali_text_embedding():
 )
 def test_colpali_file_embedding(request, embed_method, file_fixture):
     """Test file embedding with Colpali model."""
-    kit = EmbedKit.colpali(model=Model.ColPali.V1_3)
+    kit = EmbedKit.colpali(model=Model.ColPali.COLPALI_V1_3)
     file_path = request.getfixturevalue(file_fixture)
     embed_func = getattr(kit, embed_method)
     result = embed_func(file_path)
embedkit-0.1.6/tests/test_utils.py (new)

@@ -0,0 +1,52 @@
+import pytest
+from pathlib import Path
+import tempfile
+from embedkit.utils import temporary_directory, pdf_to_images
+
+
+@pytest.fixture
+def sample_pdf_path():
+    """Fixture to provide a sample PDF for testing."""
+    path = Path("tests/fixtures/2407.01449v6_p1.pdf")
+    if not path.exists():
+        pytest.skip(f"Test fixture not found: {path}")
+    return path
+
+
+def test_temporary_directory_cleanup():
+    """Test that temporary directory is properly cleaned up after use."""
+    temp_dir = None
+    with temporary_directory() as temp_path:
+        temp_dir = temp_path
+        # Create a test file in the temp directory
+        test_file = temp_path / "test.txt"
+        test_file.write_text("test content")
+        assert test_file.exists()
+        assert temp_path.exists()
+
+    # After the context manager exits, both the file and directory should be gone
+    assert not temp_dir.exists()
+    assert not test_file.exists()
+
+
+def test_pdf_to_images_temporary_files(sample_pdf_path):
+    """Test that PDF to images conversion creates and cleans up temporary files properly."""
+    # Convert PDF to images
+    image_paths = pdf_to_images(sample_pdf_path)
+
+    # Check that we got image paths
+    assert len(image_paths) > 0
+
+    # Verify all images exist and are in temp directory
+    for img_path in image_paths:
+        assert img_path.exists()
+        assert str(img_path).startswith(tempfile.gettempdir())
+        assert img_path.suffix == ".png"
+
+        # Verify the image is readable
+        assert img_path.stat().st_size > 0
+
+    # Clean up the temporary files
+    for img_path in image_paths:
+        img_path.unlink()
+        assert not img_path.exists()
{embedkit-0.1.4 → embedkit-0.1.6}/uv.lock

@@ -349,7 +349,7 @@ wheels = [
 
 [[package]]
 name = "embedkit"
-version = "0.1.
+version = "0.1.5"
 source = { editable = "." }
 dependencies = [
     { name = "accelerate" },
@@ -377,7 +377,7 @@ requires-dist = [
     { name = "pdf2image", specifier = ">=1.17.0" },
     { name = "pillow", specifier = ">=11.2.1" },
     { name = "torch", specifier = "<=2.5" },
-    { name = "transformers" },
+    { name = "transformers", specifier = ">=4.46.2" },
 ]
 
 [package.metadata.requires-dev]
embedkit-0.1.4/src/embedkit/base.py (deleted)

@@ -1,53 +0,0 @@
-# ./src/embedkit/base.py
-"""Base classes for EmbedKit."""
-
-from abc import ABC, abstractmethod
-from typing import Union, List, Optional
-from pathlib import Path
-import numpy as np
-from dataclasses import dataclass
-
-
-@dataclass
-class EmbeddingObject:
-    embedding: np.ndarray
-    source_b64: str = None
-
-
-@dataclass
-class EmbeddingResponse:
-    model_name: str
-    model_provider: str
-    input_type: str
-    objects: List[EmbeddingObject]
-
-    @property
-    def shape(self) -> tuple:
-        return self.objects[0].embedding.shape
-
-
-class EmbeddingProvider(ABC):
-    """Abstract base class for embedding providers."""
-
-    @abstractmethod
-    def embed_text(self, texts: Union[str, List[str]], **kwargs) -> EmbeddingResponse:
-        """Generate document text embeddings using the configured provider."""
-        pass
-
-    @abstractmethod
-    def embed_image(
-        self, images: Union[Path, str, List[Union[Path, str]]]
-    ) -> EmbeddingResponse:
-        """Generate image embeddings using the configured provider."""
-        pass
-
-    @abstractmethod
-    def embed_pdf(self, pdf: Union[Path, str]) -> EmbeddingResponse:
-        """Generate image embeddings from PDFs using the configured provider. Takes a single PDF file."""
-        pass
-
-
-class EmbeddingError(Exception):
-    """Base exception for embedding-related errors."""
-
-    pass
embedkit-0.1.4/src/embedkit/providers/colpali.py (deleted)

@@ -1,160 +0,0 @@
-# ./src/embedkit/providers/colpali.py
-"""ColPali embedding provider."""
-
-from typing import Union, List, Optional
-from pathlib import Path
-import logging
-import numpy as np
-import torch
-from PIL import Image
-
-from ..utils import pdf_to_images, image_to_base64
-from ..base import EmbeddingProvider, EmbeddingError, EmbeddingResponse, EmbeddingObject
-
-logger = logging.getLogger(__name__)
-
-
-class ColPaliProvider(EmbeddingProvider):
-    """ColPali embedding provider for document understanding."""
-
-    def __init__(
-        self,
-        model_name: str,
-        text_batch_size: int,
-        image_batch_size: int,
-        device: Optional[str] = None,
-    ):
-        self.model_name = model_name
-        self.provider_name = "ColPali"
-        self.text_batch_size = text_batch_size
-        self.image_batch_size = image_batch_size
-
-        # Auto-detect device
-        if device is None:
-            if torch.cuda.is_available():
-                device = "cuda"
-            elif torch.backends.mps.is_available():
-                device = "mps"
-            else:
-                device = "cpu"
-
-        self.device = device
-        self._model = None
-        self._processor = None
-
-    def _load_model(self):
-        """Lazy load the model."""
-        if self._model is None:
-            try:
-                from colpali_engine.models import ColPali, ColPaliProcessor
-
-                self._model = ColPali.from_pretrained(
-                    self.model_name,
-                    torch_dtype=torch.bfloat16,
-                    device_map=self.device,
-                ).eval()
-
-                self._processor = ColPaliProcessor.from_pretrained(self.model_name)
-                logger.info(f"Loaded ColPali model on {self.device}")
-
-            except ImportError as e:
-                raise EmbeddingError(
-                    "ColPali not installed. Run: pip install colpali-engine"
-                ) from e
-            except Exception as e:
-                raise EmbeddingError(f"Failed to load model: {e}") from e
-
-    def embed_text(self, texts: Union[str, List[str]]) -> EmbeddingResponse:
-        """Generate embeddings for text inputs."""
-        self._load_model()
-
-        if isinstance(texts, str):
-            texts = [texts]
-
-        try:
-            # Process texts in batches
-            all_embeddings = []
-
-            for i in range(0, len(texts), self.text_batch_size):
-                batch_texts = texts[i : i + self.text_batch_size]
-                processed = self._processor.process_queries(batch_texts).to(self.device)
-
-                with torch.no_grad():
-                    batch_embeddings = self._model(**processed)
-                    all_embeddings.append(batch_embeddings.cpu().float().numpy())
-
-            # Concatenate all batch embeddings
-            final_embeddings = np.concatenate(all_embeddings, axis=0)
-
-            return EmbeddingResponse(
-                model_name=self.model_name,
-                model_provider=self.provider_name,
-                input_type="text",
-                objects=[
-                    EmbeddingObject(
-                        embedding=e,
-                    ) for e in final_embeddings
-                ]
-            )
-
-        except Exception as e:
-            raise EmbeddingError(f"Failed to embed text: {e}") from e
-
-    def embed_image(
-        self, images: Union[Path, str, List[Union[Path, str]]]
-    ) -> EmbeddingResponse:
-        """Generate embeddings for images."""
-        self._load_model()
-
-        if isinstance(images, (str, Path)):
-            images = [Path(images)]
-        else:
-            images = [Path(img) for img in images]
-
-        try:
-            # Process images in batches
-            all_embeddings = []
-            all_b64_images = []
-
-            for i in range(0, len(images), self.image_batch_size):
-                batch_images = images[i : i + self.image_batch_size]
-                pil_images = []
-                b64_images = []
-
-                for img_path in batch_images:
-                    if not img_path.exists():
-                        raise EmbeddingError(f"Image not found: {img_path}")
-
-                    with Image.open(img_path) as img:
-                        pil_images.append(img.convert("RGB"))
-                        b64_images.append(image_to_base64(img_path))
-
-                processed = self._processor.process_images(pil_images).to(self.device)
-
-                with torch.no_grad():
-                    batch_embeddings = self._model(**processed)
-                    all_embeddings.append(batch_embeddings.cpu().float().numpy())
-                    all_b64_images.extend(b64_images)
-
-            # Concatenate all batch embeddings
-            final_embeddings = np.concatenate(all_embeddings, axis=0)
-
-            return EmbeddingResponse(
-                model_name=self.model_name,
-                model_provider=self.provider_name,
-                input_type="image",
-                objects=[
-                    EmbeddingObject(
-                        embedding=final_embeddings[i],
-                        source_b64=all_b64_images[i]
-                    ) for i in range(len(final_embeddings))
-                ]
-            )
-
-        except Exception as e:
-            raise EmbeddingError(f"Failed to embed images: {e}") from e
-
-    def embed_pdf(self, pdf_path: Path) -> EmbeddingResponse:
-        """Generate embeddings for a PDF file using ColPali API."""
-        images = pdf_to_images(pdf_path)
-        return self.embed_image(images)
embedkit-0.1.4/src/embedkit/utils.py (deleted)

@@ -1,48 +0,0 @@
-from pdf2image import convert_from_path
-from pathlib import Path
-from .config import get_temp_dir
-from typing import Union
-
-
-def pdf_to_images(pdf_path: Path) -> list[Path]:
-    """Convert a PDF file to a list of images."""
-    root_temp_dir = get_temp_dir()
-    img_temp_dir = root_temp_dir / "images"
-    img_temp_dir.mkdir(parents=True, exist_ok=True)
-    images = convert_from_path(pdf_path=str(pdf_path), output_folder=str(img_temp_dir))
-    image_paths = []
-
-    for i, image in enumerate(images):
-        output_path = img_temp_dir / f"{pdf_path.stem}_{i}.png"
-        if output_path.exists():
-            output_path.unlink()
-
-        image.save(output_path)
-        image_paths.append(output_path)
-    return image_paths
-
-
-def image_to_base64(image_path: Union[str, Path]):
-    import base64
-
-    try:
-        base64_only = base64.b64encode(Path(image_path).read_bytes()).decode("utf-8")
-    except Exception as e:
-        raise ValueError(f"Failed to read image {image_path}: {e}") from e
-
-    if isinstance(image_path, Path):
-        image_path_str = str(image_path)
-
-    if image_path_str.lower().endswith(".png"):
-        content_type = "image/png"
-    elif image_path_str.lower().endswith((".jpg", ".jpeg")):
-        content_type = "image/jpeg"
-    elif image_path_str.lower().endswith(".gif"):
-        content_type = "image/gif"
-    else:
-        raise ValueError(
-            f"Unsupported image format for {image_path}; expected .png, .jpg, .jpeg, or .gif"
-        )
-    base64_image = f"data:{content_type};base64,{base64_only}"
-
-    return base64_image