embedkit 0.1.0__tar.gz → 0.1.2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {embedkit-0.1.0 → embedkit-0.1.2}/PKG-INFO +1 -1
- embedkit-0.1.2/main.py +97 -0
- {embedkit-0.1.0 → embedkit-0.1.2}/pyproject.toml +1 -1
- {embedkit-0.1.0 → embedkit-0.1.2}/src/embedkit/__init__.py +23 -4
- {embedkit-0.1.0 → embedkit-0.1.2}/src/embedkit/base.py +3 -4
- {embedkit-0.1.0 → embedkit-0.1.2}/src/embedkit/providers/cohere.py +45 -50
- {embedkit-0.1.0 → embedkit-0.1.2}/src/embedkit/providers/colpali.py +48 -17
- embedkit-0.1.2/src/embedkit/utils.py +48 -0
- embedkit-0.1.0/main.py +0 -78
- embedkit-0.1.0/src/embedkit/utils.py +0 -21
- {embedkit-0.1.0 → embedkit-0.1.2}/.gitignore +0 -0
- {embedkit-0.1.0 → embedkit-0.1.2}/.python-version +0 -0
- {embedkit-0.1.0 → embedkit-0.1.2}/LICENSE +0 -0
- {embedkit-0.1.0 → embedkit-0.1.2}/README.md +0 -0
- {embedkit-0.1.0 → embedkit-0.1.2}/src/embedkit/config.py +0 -0
- {embedkit-0.1.0 → embedkit-0.1.2}/src/embedkit/models.py +0 -0
- {embedkit-0.1.0 → embedkit-0.1.2}/src/embedkit/providers/__init__.py +0 -0
- {embedkit-0.1.0 → embedkit-0.1.2}/tests/conftest.py +0 -0
- {embedkit-0.1.0 → embedkit-0.1.2}/tests/fixtures/2407.01449v6_p1.pdf +0 -0
- {embedkit-0.1.0 → embedkit-0.1.2}/tests/fixtures/2407.01449v6_p1.png +0 -0
- {embedkit-0.1.0 → embedkit-0.1.2}/tests/test_embedkit.py +0 -0
- {embedkit-0.1.0 → embedkit-0.1.2}/uv.lock +0 -0
embedkit-0.1.2/main.py
ADDED
@@ -0,0 +1,97 @@
|
|
1
|
+
# ./main.py
|
2
|
+
from embedkit import EmbedKit
|
3
|
+
from embedkit.models import Model
|
4
|
+
from embedkit.providers.cohere import CohereInputType
|
5
|
+
from pathlib import Path
|
6
|
+
import os
|
7
|
+
|
8
|
+
|
9
|
+
def get_online_image(url: str, suffix: str = ".png", timeout: float = 30.0) -> Path:
    """Download an image from a URL and return the path of a local copy.

    The image is written to a named temporary file that is NOT deleted
    automatically; the caller is responsible for removing it when done.

    Args:
        url: HTTP(S) URL of the image to fetch.
        suffix: Extension given to the temporary file. The embedding
            providers derive the image MIME type from the file extension,
            so this should match the real image format.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout, ``requests`` can block forever on a stalled connection.

    Returns:
        Path to the downloaded file.

    Raises:
        requests.HTTPError: If the server responds with an error status.
    """
    import requests
    from tempfile import NamedTemporaryFile

    # Add User-Agent header to comply with Wikipedia's policy
    headers = {"User-Agent": "EmbedKit-Example/1.0"}

    response = requests.get(url, headers=headers, timeout=timeout)
    response.raise_for_status()

    # delete=False so the file outlives this function; the context manager
    # still guarantees the handle is closed even if the write fails.
    with NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
        temp_file.write(response.content)

    return Path(temp_file.name)
|
25
|
+
|
26
|
+
|
27
|
+
def get_sample_image() -> Path:
    """Fetch a fixed sample PNG from Wikimedia Commons and return its local path."""
    sample_url = (
        "https://upload.wikimedia.org/wikipedia/commons/b/b8/"
        "English_Wikipedia_HomePage_2001-12-20.png"
    )
    local_copy = get_online_image(sample_url)
    return local_copy
|
31
|
+
|
32
|
+
|
33
|
+
sample_image = get_sample_image()

sample_pdf = Path("tests/fixtures/2407.01449v6_p1.pdf")
# Full multi-page PDF. It lives under tmp/ rather than in the committed test
# fixtures, so the multi-page checks below are skipped when it is absent.
long_pdf = Path("tmp/2407.01449v6.pdf")
LONG_PDF_PAGES = 26


def check_result(results, ndim: int, rows: int = 1, expect_images: bool = False) -> None:
    """Assert a result has `rows` entries and `ndim` dimensions (and, optionally, source images)."""
    assert results.shape[0] == rows
    assert len(results.shape) == ndim
    if expect_images:
        assert len(results.source_images_b64) > 0


def cohere_kit(text_input_type: CohereInputType) -> EmbedKit:
    """Build the Cohere EmbedKit used by the examples below for the given text input type."""
    return EmbedKit.cohere(
        model=Model.Cohere.EMBED_V4_0,
        api_key=os.getenv("COHERE_API_KEY"),
        text_batch_size=64,
        image_batch_size=8,
        text_input_type=text_input_type,
    )


try:
    # ColPali results are 3-D.
    kit = EmbedKit.colpali(model=Model.ColPali.V1_3, text_batch_size=16, image_batch_size=8)

    check_result(kit.embed_text("Hello world"), ndim=3)
    check_result(kit.embed_image(sample_image), ndim=3, expect_images=True)
    check_result(kit.embed_pdf(sample_pdf), ndim=3, expect_images=True)
    if long_pdf.exists():
        check_result(kit.embed_pdf(long_pdf), ndim=3, rows=LONG_PDF_PAGES, expect_images=True)
    else:
        print(f"Skipping multi-page ColPali check: {long_pdf} not found")

    # Cohere results are 2-D.
    kit = cohere_kit(CohereInputType.SEARCH_QUERY)
    check_result(kit.embed_text("Hello world"), ndim=2)

    kit = cohere_kit(CohereInputType.SEARCH_DOCUMENT)
    check_result(kit.embed_text("Hello world"), ndim=2)
    check_result(kit.embed_image(sample_image), ndim=2, expect_images=True)
    check_result(kit.embed_pdf(sample_pdf), ndim=2, expect_images=True)
    if long_pdf.exists():
        check_result(kit.embed_pdf(long_pdf), ndim=2, rows=LONG_PDF_PAGES, expect_images=True)
    else:
        print(f"Skipping multi-page Cohere check: {long_pdf} not found")
finally:
    # get_sample_image() downloads into a temp file created with delete=False;
    # remove it so repeated runs do not accumulate files in the temp dir.
    sample_image.unlink(missing_ok=True)
|
@@ -26,21 +26,33 @@ class EmbedKit:
|
|
26
26
|
self._provider = provider_instance
|
27
27
|
|
28
28
|
@classmethod
def colpali(
    cls,
    model: Model = Model.ColPali.V1_3,
    device: Optional[str] = None,
    text_batch_size: int = 32,
    image_batch_size: int = 8,
):
    """
    Create EmbedKit instance with ColPali provider.

    Args:
        model: ColPali model enum (only ``Model.ColPali.V1_3`` is supported)
        device: Device to run on ('cuda', 'mps', 'cpu', or None for auto-detect)
        text_batch_size: Batch size for text embedding generation (must be >= 1)
        image_batch_size: Batch size for image embedding generation (must be >= 1)

    Returns:
        An instance of this class wrapping a ColPaliProvider.

    Raises:
        ValueError: If the model is unsupported or a batch size is < 1.
    """
    if model == Model.ColPali.V1_3:
        model_name = "vidore/colpali-v1.3"
    else:
        raise ValueError(f"Unsupported model: {model}")

    # The provider steps through inputs with range(0, n, batch_size). A zero
    # step makes range() raise and a negative one yields no batches, and either
    # problem would only surface later inside embed_*. Fail fast here instead.
    for param_name, size in (
        ("text_batch_size", text_batch_size),
        ("image_batch_size", image_batch_size),
    ):
        if size < 1:
            raise ValueError(f"{param_name} must be >= 1, got {size}")

    provider = ColPaliProvider(
        model_name=model_name,
        device=device,
        text_batch_size=text_batch_size,
        image_batch_size=image_batch_size,
    )
    return cls(provider)
|
45
57
|
|
46
58
|
@classmethod
|
@@ -48,6 +60,8 @@ class EmbedKit:
|
|
48
60
|
cls,
|
49
61
|
api_key: str,
|
50
62
|
model: Model = Model.Cohere.EMBED_V4_0,
|
63
|
+
text_batch_size: int = 32,
|
64
|
+
image_batch_size: int = 8,
|
51
65
|
text_input_type: CohereInputType = CohereInputType.SEARCH_DOCUMENT,
|
52
66
|
):
|
53
67
|
"""
|
@@ -56,6 +70,8 @@ class EmbedKit:
|
|
56
70
|
Args:
|
57
71
|
api_key: Cohere API key
|
58
72
|
model: Cohere model enum
|
73
|
+
text_batch_size: Batch size for text embedding generation
|
74
|
+
image_batch_size: Batch size for image embedding generation
|
59
75
|
input_type: Type of input for embedding (search_document or search_query)
|
60
76
|
"""
|
61
77
|
if not api_key:
|
@@ -67,7 +83,10 @@ class EmbedKit:
|
|
67
83
|
raise ValueError(f"Unsupported model: {model}")
|
68
84
|
|
69
85
|
provider = CohereProvider(
|
70
|
-
api_key=api_key, model_name=model_name,
|
86
|
+
api_key=api_key, model_name=model_name,
|
87
|
+
text_batch_size=48,
|
88
|
+
image_batch_size=8,
|
89
|
+
text_input_type=text_input_type
|
71
90
|
)
|
72
91
|
return cls(provider)
|
73
92
|
|
@@ -2,7 +2,7 @@
|
|
2
2
|
"""Base classes for EmbedKit."""
|
3
3
|
|
4
4
|
from abc import ABC, abstractmethod
|
5
|
-
from typing import Union, List
|
5
|
+
from typing import Union, List, Optional
|
6
6
|
from pathlib import Path
|
7
7
|
import numpy as np
|
8
8
|
from dataclasses import dataclass
|
@@ -14,6 +14,7 @@ class EmbeddingResult:
|
|
14
14
|
model_name: str
|
15
15
|
model_provider: str
|
16
16
|
input_type: str
|
17
|
+
source_images_b64: Optional[List[str]] = None
|
17
18
|
|
18
19
|
@property
|
19
20
|
def shape(self) -> tuple:
|
@@ -36,9 +37,7 @@ class EmbeddingProvider(ABC):
|
|
36
37
|
pass
|
37
38
|
|
38
39
|
@abstractmethod
def embed_pdf(self, pdf: Union[Path, str]) -> EmbeddingResult:
    """Generate image embeddings from a single PDF file using the configured provider.

    Args:
        pdf: Path to the PDF file, as a ``Path`` or ``str``.

    Returns:
        The provider's EmbeddingResult for the PDF.
    """
    # NOTE(review): concrete providers in this package name this parameter
    # ``pdf_path``; pass it positionally to stay provider-agnostic.
    pass
|
44
43
|
|
@@ -6,7 +6,7 @@ from pathlib import Path
|
|
6
6
|
import numpy as np
|
7
7
|
from enum import Enum
|
8
8
|
|
9
|
-
from ..utils import pdf_to_images
|
9
|
+
from ..utils import pdf_to_images, image_to_base64
|
10
10
|
from ..base import EmbeddingProvider, EmbeddingError, EmbeddingResult
|
11
11
|
|
12
12
|
|
@@ -24,10 +24,14 @@ class CohereProvider(EmbeddingProvider):
|
|
24
24
|
self,
|
25
25
|
api_key: str,
|
26
26
|
model_name: str,
|
27
|
+
text_batch_size: int,
|
28
|
+
image_batch_size: int,
|
27
29
|
text_input_type: CohereInputType = CohereInputType.SEARCH_DOCUMENT,
|
28
30
|
):
|
29
31
|
self.api_key = api_key
|
30
32
|
self.model_name = model_name
|
33
|
+
self.text_batch_size = text_batch_size
|
34
|
+
self.image_batch_size = image_batch_size
|
31
35
|
self.input_type = text_input_type
|
32
36
|
self._client = None
|
33
37
|
self.provider_name = "Cohere"
|
@@ -55,15 +59,21 @@ class CohereProvider(EmbeddingProvider):
|
|
55
59
|
texts = [texts]
|
56
60
|
|
57
61
|
try:
|
58
|
-
|
59
|
-
|
60
|
-
|
61
|
-
|
62
|
-
|
63
|
-
|
62
|
+
all_embeddings = []
|
63
|
+
|
64
|
+
# Process texts in batches
|
65
|
+
for i in range(0, len(texts), self.text_batch_size):
|
66
|
+
batch_texts = texts[i : i + self.text_batch_size]
|
67
|
+
response = client.embed(
|
68
|
+
texts=batch_texts,
|
69
|
+
model=self.model_name,
|
70
|
+
input_type=self.input_type.value,
|
71
|
+
embedding_types=["float"],
|
72
|
+
)
|
73
|
+
all_embeddings.extend(response.embeddings.float_)
|
64
74
|
|
65
75
|
return EmbeddingResult(
|
66
|
-
embeddings=np.array(
|
76
|
+
embeddings=np.array(all_embeddings),
|
67
77
|
model_name=self.model_name,
|
68
78
|
model_provider=self.provider_name,
|
69
79
|
input_type=self.input_type.value,
|
@@ -81,60 +91,45 @@ class CohereProvider(EmbeddingProvider):
|
|
81
91
|
input_type = "image"
|
82
92
|
|
83
93
|
if isinstance(images, (str, Path)):
|
84
|
-
images = [images]
|
94
|
+
images = [Path(images)]
|
95
|
+
else:
|
96
|
+
images = [Path(img) for img in images]
|
85
97
|
|
86
98
|
try:
|
87
|
-
|
88
|
-
|
89
|
-
|
90
|
-
|
91
|
-
|
92
|
-
|
93
|
-
|
94
|
-
|
95
|
-
|
96
|
-
|
97
|
-
raise EmbeddingError(
|
98
|
-
|
99
|
-
|
100
|
-
|
101
|
-
|
102
|
-
|
103
|
-
|
104
|
-
|
105
|
-
|
106
|
-
|
107
|
-
|
108
|
-
|
109
|
-
content_type = "image/gif"
|
110
|
-
else:
|
111
|
-
raise EmbeddingError(
|
112
|
-
f"Unsupported image format for {image}; expected .png, .jpg, .jpeg, or .gif"
|
113
|
-
)
|
114
|
-
base64_image = f"data:{content_type};base64,{base64_only}"
|
115
|
-
else:
|
116
|
-
raise EmbeddingError(f"Unsupported image type: {type(image)}")
|
117
|
-
|
118
|
-
b64_images.append(base64_image)
|
119
|
-
|
120
|
-
response = client.embed(
|
121
|
-
model=self.model_name,
|
122
|
-
input_type="image",
|
123
|
-
images=b64_images,
|
124
|
-
embedding_types=["float"],
|
125
|
-
)
|
99
|
+
all_embeddings = []
|
100
|
+
all_b64_images = []
|
101
|
+
|
102
|
+
# Process images in batches
|
103
|
+
for i in range(0, len(images), self.image_batch_size):
|
104
|
+
batch_images = images[i : i + self.image_batch_size]
|
105
|
+
b64_images = []
|
106
|
+
|
107
|
+
for image in batch_images:
|
108
|
+
if not image.exists():
|
109
|
+
raise EmbeddingError(f"Image not found: {image}")
|
110
|
+
b64_images.append(image_to_base64(image))
|
111
|
+
|
112
|
+
response = client.embed(
|
113
|
+
model=self.model_name,
|
114
|
+
input_type="image",
|
115
|
+
images=b64_images,
|
116
|
+
embedding_types=["float"],
|
117
|
+
)
|
118
|
+
|
119
|
+
all_embeddings.extend(response.embeddings.float_)
|
120
|
+
all_b64_images.extend(b64_images)
|
126
121
|
|
127
122
|
return EmbeddingResult(
|
128
|
-
embeddings=np.array(
|
123
|
+
embeddings=np.array(all_embeddings),
|
129
124
|
model_name=self.model_name,
|
130
125
|
model_provider=self.provider_name,
|
131
126
|
input_type=input_type,
|
127
|
+
source_images_b64=all_b64_images,
|
132
128
|
)
|
133
129
|
|
134
130
|
except Exception as e:
|
135
131
|
raise EmbeddingError(f"Failed to embed image with Cohere: {e}") from e
|
136
132
|
|
137
|
-
|
138
133
|
def embed_pdf(self, pdf_path: Path) -> EmbeddingResult:
|
139
134
|
"""Generate embeddings for a PDF file using Cohere API."""
|
140
135
|
image_paths = pdf_to_images(pdf_path)
|
@@ -8,7 +8,7 @@ import numpy as np
|
|
8
8
|
import torch
|
9
9
|
from PIL import Image
|
10
10
|
|
11
|
-
from ..utils import pdf_to_images
|
11
|
+
from ..utils import pdf_to_images, image_to_base64
|
12
12
|
from ..base import EmbeddingProvider, EmbeddingError, EmbeddingResult
|
13
13
|
|
14
14
|
logger = logging.getLogger(__name__)
|
@@ -17,9 +17,17 @@ logger = logging.getLogger(__name__)
|
|
17
17
|
class ColPaliProvider(EmbeddingProvider):
|
18
18
|
"""ColPali embedding provider for document understanding."""
|
19
19
|
|
20
|
-
def __init__(
|
20
|
+
def __init__(
|
21
|
+
self,
|
22
|
+
model_name: str,
|
23
|
+
text_batch_size: int,
|
24
|
+
image_batch_size: int,
|
25
|
+
device: Optional[str] = None,
|
26
|
+
):
|
21
27
|
self.model_name = model_name
|
22
28
|
self.provider_name = "ColPali"
|
29
|
+
self.text_batch_size = text_batch_size
|
30
|
+
self.image_batch_size = image_batch_size
|
23
31
|
|
24
32
|
# Auto-detect device
|
25
33
|
if device is None:
|
@@ -64,13 +72,22 @@ class ColPaliProvider(EmbeddingProvider):
|
|
64
72
|
texts = [texts]
|
65
73
|
|
66
74
|
try:
|
67
|
-
|
75
|
+
# Process texts in batches
|
76
|
+
all_embeddings = []
|
68
77
|
|
69
|
-
|
70
|
-
|
78
|
+
for i in range(0, len(texts), self.text_batch_size):
|
79
|
+
batch_texts = texts[i : i + self.text_batch_size]
|
80
|
+
processed = self._processor.process_queries(batch_texts).to(self.device)
|
81
|
+
|
82
|
+
with torch.no_grad():
|
83
|
+
batch_embeddings = self._model(**processed)
|
84
|
+
all_embeddings.append(batch_embeddings.cpu().float().numpy())
|
85
|
+
|
86
|
+
# Concatenate all batch embeddings
|
87
|
+
final_embeddings = np.concatenate(all_embeddings, axis=0)
|
71
88
|
|
72
89
|
return EmbeddingResult(
|
73
|
-
embeddings=
|
90
|
+
embeddings=final_embeddings,
|
74
91
|
model_name=self.model_name,
|
75
92
|
model_provider=self.provider_name,
|
76
93
|
input_type="text",
|
@@ -91,30 +108,44 @@ class ColPaliProvider(EmbeddingProvider):
|
|
91
108
|
images = [Path(img) for img in images]
|
92
109
|
|
93
110
|
try:
|
94
|
-
|
95
|
-
|
96
|
-
|
97
|
-
|
111
|
+
# Process images in batches
|
112
|
+
all_embeddings = []
|
113
|
+
all_b64_images = []
|
114
|
+
|
115
|
+
for i in range(0, len(images), self.image_batch_size):
|
116
|
+
batch_images = images[i : i + self.image_batch_size]
|
117
|
+
pil_images = []
|
118
|
+
b64_images = []
|
98
119
|
|
99
|
-
|
100
|
-
|
120
|
+
for img_path in batch_images:
|
121
|
+
if not img_path.exists():
|
122
|
+
raise EmbeddingError(f"Image not found: {img_path}")
|
101
123
|
|
102
|
-
|
124
|
+
with Image.open(img_path) as img:
|
125
|
+
pil_images.append(img.convert("RGB"))
|
126
|
+
b64_images.append(image_to_base64(img_path))
|
103
127
|
|
104
|
-
|
105
|
-
|
128
|
+
processed = self._processor.process_images(pil_images).to(self.device)
|
129
|
+
|
130
|
+
with torch.no_grad():
|
131
|
+
batch_embeddings = self._model(**processed)
|
132
|
+
all_embeddings.append(batch_embeddings.cpu().float().numpy())
|
133
|
+
all_b64_images.extend(b64_images)
|
134
|
+
|
135
|
+
# Concatenate all batch embeddings
|
136
|
+
final_embeddings = np.concatenate(all_embeddings, axis=0)
|
106
137
|
|
107
138
|
return EmbeddingResult(
|
108
|
-
embeddings=
|
139
|
+
embeddings=final_embeddings,
|
109
140
|
model_name=self.model_name,
|
110
141
|
model_provider=self.provider_name,
|
111
142
|
input_type="image",
|
143
|
+
source_images_b64=all_b64_images,
|
112
144
|
)
|
113
145
|
|
114
146
|
except Exception as e:
|
115
147
|
raise EmbeddingError(f"Failed to embed images: {e}") from e
|
116
148
|
|
117
|
-
|
118
149
|
def embed_pdf(self, pdf_path: Path) -> EmbeddingResult:
|
119
150
|
"""Generate embeddings for a PDF file using ColPali API."""
|
120
151
|
images = pdf_to_images(pdf_path)
|
@@ -0,0 +1,48 @@
|
|
1
|
+
from pdf2image import convert_from_path
|
2
|
+
from pathlib import Path
|
3
|
+
from .config import get_temp_dir
|
4
|
+
from typing import Union
|
5
|
+
|
6
|
+
|
7
|
+
def pdf_to_images(pdf_path: Union[str, Path]) -> list[Path]:
    """Render every page of a PDF to a PNG file and return the file paths.

    Pages are written to ``<temp dir>/images/<pdf stem>_<page index>.png``
    (page index starting at 0), overwriting any earlier file with that name.

    Args:
        pdf_path: Path to the PDF file; a plain ``str`` is accepted too.

    Returns:
        The PNG paths, in page order.
    """
    import tempfile

    # Accept str as well as Path: ``.stem`` below needs a Path object.
    pdf_path = Path(pdf_path)

    img_temp_dir = get_temp_dir() / "images"
    img_temp_dir.mkdir(parents=True, exist_ok=True)

    image_paths = []
    # pdf2image writes its own page files into output_folder. Use a throwaway
    # directory for them so they are removed once the PNGs have been saved,
    # instead of piling up in img_temp_dir on every call.
    with tempfile.TemporaryDirectory() as scratch_dir:
        pages = convert_from_path(pdf_path=str(pdf_path), output_folder=scratch_dir)
        for index, page in enumerate(pages):
            output_path = img_temp_dir / f"{pdf_path.stem}_{index}.png"
            page.save(output_path)  # save() truncates/overwrites an existing file
            image_paths.append(output_path)

    return image_paths
|
23
|
+
|
24
|
+
|
25
|
+
def image_to_base64(image_path: Union[str, Path]) -> str:
    """Read an image file and return it as a base64 ``data:`` URI.

    The MIME type is chosen from the file extension (case-insensitive):
    ``.png`` -> ``image/png``, ``.jpg``/``.jpeg`` -> ``image/jpeg``,
    ``.gif`` -> ``image/gif``.

    Args:
        image_path: Path to the image, as a ``str`` or ``Path``.

    Returns:
        A string of the form ``data:<mime>;base64,<payload>``.

    Raises:
        ValueError: If the file cannot be read, or its extension is not one of
            the supported formats.
    """
    import base64

    try:
        base64_only = base64.b64encode(Path(image_path).read_bytes()).decode("utf-8")
    except Exception as e:
        raise ValueError(f"Failed to read image {image_path}: {e}") from e

    # str() works for both accepted input types; converting only Path inputs
    # left this name unbound for str arguments.
    image_path_str = str(image_path).lower()

    if image_path_str.endswith(".png"):
        content_type = "image/png"
    elif image_path_str.endswith((".jpg", ".jpeg")):
        content_type = "image/jpeg"
    elif image_path_str.endswith(".gif"):
        content_type = "image/gif"
    else:
        raise ValueError(
            f"Unsupported image format for {image_path}; expected .png, .jpg, .jpeg, or .gif"
        )

    return f"data:{content_type};base64,{base64_only}"
|
embedkit-0.1.0/main.py
DELETED
@@ -1,78 +0,0 @@
|
|
1
|
-
# ./main.py
|
2
|
-
from embedkit import EmbedKit
|
3
|
-
from embedkit.models import Model
|
4
|
-
from embedkit.providers.cohere import CohereInputType
|
5
|
-
from pathlib import Path
|
6
|
-
import os
|
7
|
-
|
8
|
-
|
9
|
-
def get_online_image(url: str, suffix: str = ".png", timeout: float = 30.0) -> Path:
    """Download an image from a URL and return the path of a local copy.

    The image is written to a named temporary file that is NOT deleted
    automatically; the caller is responsible for removing it when done.

    Args:
        url: HTTP(S) URL of the image to fetch.
        suffix: Extension given to the temporary file; it should match the
            real image format.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout, ``requests`` can block forever on a stalled connection.

    Returns:
        Path to the downloaded file.

    Raises:
        requests.HTTPError: If the server responds with an error status.
    """
    import requests
    from tempfile import NamedTemporaryFile

    # Add User-Agent header to comply with Wikipedia's policy
    headers = {"User-Agent": "EmbedKit-Example/1.0"}

    response = requests.get(url, headers=headers, timeout=timeout)
    response.raise_for_status()

    # delete=False so the file outlives this function; the context manager
    # still guarantees the handle is closed even if the write fails.
    with NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
        temp_file.write(response.content)

    return Path(temp_file.name)
|
25
|
-
|
26
|
-
|
27
|
-
def get_sample_image() -> Path:
    """Fetch a fixed sample PNG from Wikimedia Commons and return its local path."""
    sample_url = (
        "https://upload.wikimedia.org/wikipedia/commons/b/b8/"
        "English_Wikipedia_HomePage_2001-12-20.png"
    )
    local_copy = get_online_image(sample_url)
    return local_copy
|
31
|
-
|
32
|
-
|
33
|
-
sample_image = get_sample_image()

sample_pdf = Path("tests/fixtures/2407.01449v6_p1.pdf")


def check_shape(embeddings, ndim: int) -> None:
    """Assert a single-input embedding result has one row and `ndim` dimensions."""
    assert embeddings.shape[0] == 1
    assert len(embeddings.shape) == ndim


def cohere_kit(text_input_type: CohereInputType) -> EmbedKit:
    """Build the Cohere EmbedKit used by the examples below for the given text input type."""
    return EmbedKit.cohere(
        model=Model.Cohere.EMBED_V4_0,
        api_key=os.getenv("COHERE_API_KEY"),
        text_input_type=text_input_type,
    )


try:
    # ColPali results are 3-D.
    kit = EmbedKit.colpali(model=Model.ColPali.V1_3)
    check_shape(kit.embed_text("Hello world"), 3)
    check_shape(kit.embed_image(sample_image), 3)
    check_shape(kit.embed_pdf(sample_pdf), 3)

    # Cohere results are 2-D.
    kit = cohere_kit(CohereInputType.SEARCH_QUERY)
    check_shape(kit.embed_text("Hello world"), 2)

    kit = cohere_kit(CohereInputType.SEARCH_DOCUMENT)
    check_shape(kit.embed_text("Hello world"), 2)
    check_shape(kit.embed_image(sample_image), 2)
    check_shape(kit.embed_pdf(sample_pdf), 2)
finally:
    # get_sample_image() downloads into a temp file created with delete=False;
    # remove it so repeated runs do not accumulate files in the temp dir.
    sample_image.unlink(missing_ok=True)
|
@@ -1,21 +0,0 @@
|
|
1
|
-
from pdf2image import convert_from_path
|
2
|
-
from pathlib import Path
|
3
|
-
from .config import get_temp_dir
|
4
|
-
|
5
|
-
|
6
|
-
def pdf_to_images(pdf_path: Path) -> list[Path]:
    """Render every page of a PDF to a PNG file and return the file paths.

    Pages are written to ``<temp dir>/images/<pdf stem>_<page index>.png``
    (page index starting at 0), overwriting any earlier file with that name.

    Args:
        pdf_path: Path to the PDF file; a plain ``str`` is accepted too.

    Returns:
        The PNG paths, in page order.
    """
    import tempfile

    # Accept str as well as Path: ``.stem`` below needs a Path object.
    pdf_path = Path(pdf_path)

    img_temp_dir = get_temp_dir() / "images"
    img_temp_dir.mkdir(parents=True, exist_ok=True)

    image_paths = []
    # pdf2image writes its own page files into output_folder. Use a throwaway
    # directory for them so they are removed once the PNGs have been saved,
    # instead of piling up in img_temp_dir on every call.
    with tempfile.TemporaryDirectory() as scratch_dir:
        pages = convert_from_path(pdf_path=str(pdf_path), output_folder=scratch_dir)
        for index, page in enumerate(pages):
            output_path = img_temp_dir / f"{pdf_path.stem}_{index}.png"
            page.save(output_path)  # save() truncates/overwrites an existing file
            image_paths.append(output_path)

    return image_paths
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|