embedkit 0.1.0__tar.gz → 0.1.2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: embedkit
3
- Version: 0.1.0
3
+ Version: 0.1.2
4
4
  Summary: A simple toolkit for generating vector embeddings across multiple providers and models
5
5
  Author-email: JP Hwang <me@jphwang.com>
6
6
  License: MIT
embedkit-0.1.2/main.py ADDED
@@ -0,0 +1,97 @@
1
+ # ./main.py
2
+ from embedkit import EmbedKit
3
+ from embedkit.models import Model
4
+ from embedkit.providers.cohere import CohereInputType
5
+ from pathlib import Path
6
+ import os
7
+
8
+
9
def get_online_image(url: str, timeout: float = 30.0) -> Path:
    """Download an image from a URL and return the path of a local copy.

    Args:
        url: Address of the image to fetch.
        timeout: Seconds to wait on the server before giving up. Without
            a timeout ``requests.get`` may block indefinitely.

    Returns:
        Path of a temporary file (suffix ``.png``) holding the response
        body. The file is created with ``delete=False``, so it is not
        removed automatically; the caller owns its cleanup.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.Timeout: If the server does not respond within ``timeout``.
    """
    import requests
    from tempfile import NamedTemporaryFile

    # Add User-Agent header to comply with Wikipedia's policy
    headers = {"User-Agent": "EmbedKit-Example/1.0"}

    response = requests.get(url, headers=headers, timeout=timeout)
    response.raise_for_status()

    # The context manager closes the handle even if the write fails;
    # delete=False keeps the file on disk after the block exits.
    with NamedTemporaryFile(delete=False, suffix=".png") as temp_file:
        temp_file.write(response.content)

    return Path(temp_file.name)
25
+
26
+
27
def get_sample_image() -> Path:
    """Fetch the example screenshot used by this script and return its local path."""
    sample_url = "https://upload.wikimedia.org/wikipedia/commons/b/b8/English_Wikipedia_HomePage_2001-12-20.png"
    local_copy = get_online_image(sample_url)
    return local_copy
31
+
32
+
33
def _expect_result(result, rows: int, rank: int, with_images: bool = False) -> None:
    """Assert the leading dimension and the rank of an embedding result.

    When ``with_images`` is true, additionally assert that the result
    carries at least one entry in ``source_images_b64``.
    """
    assert result.shape[0] == rows
    assert len(result.shape) == rank
    if with_images:
        assert len(result.source_images_b64) > 0


# Inputs shared by every provider run below.
sample_image = get_sample_image()

sample_pdf = Path("tests/fixtures/2407.01449v6_p1.pdf")
long_pdf = Path("tmp/2407.01449v6.pdf")

# --- ColPali: each result is checked to have three dimensions ---
kit = EmbedKit.colpali(model=Model.ColPali.V1_3, text_batch_size=16, image_batch_size=8)

results = kit.embed_text("Hello world")
_expect_result(results, rows=1, rank=3)

results = kit.embed_image(sample_image)
_expect_result(results, rows=1, rank=3, with_images=True)

results = kit.embed_pdf(sample_pdf)
_expect_result(results, rows=1, rank=3, with_images=True)

results = kit.embed_pdf(long_pdf)
_expect_result(results, rows=26, rank=3, with_images=True)


# --- Cohere with query-type text input: two-dimensional results ---
kit = EmbedKit.cohere(
    model=Model.Cohere.EMBED_V4_0,
    api_key=os.getenv("COHERE_API_KEY"),
    text_batch_size=64,
    image_batch_size=8,
    text_input_type=CohereInputType.SEARCH_QUERY,
)

results = kit.embed_text("Hello world")
_expect_result(results, rows=1, rank=2)

# --- Cohere with document-type text input: text, image and PDF checks ---
kit = EmbedKit.cohere(
    model=Model.Cohere.EMBED_V4_0,
    api_key=os.getenv("COHERE_API_KEY"),
    text_batch_size=64,
    image_batch_size=8,
    text_input_type=CohereInputType.SEARCH_DOCUMENT,
)

results = kit.embed_text("Hello world")
_expect_result(results, rows=1, rank=2)

results = kit.embed_image(sample_image)
_expect_result(results, rows=1, rank=2, with_images=True)

results = kit.embed_pdf(sample_pdf)
_expect_result(results, rows=1, rank=2, with_images=True)

results = kit.embed_pdf(long_pdf)
_expect_result(results, rows=26, rank=2, with_images=True)
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "embedkit"
3
- version = "0.1.0"
3
+ version = "0.1.2"
4
4
  description = "A simple toolkit for generating vector embeddings across multiple providers and models"
5
5
  readme = "README.md"
6
6
  requires-python = ">=3.10"
@@ -26,21 +26,33 @@ class EmbedKit:
26
26
  self._provider = provider_instance
27
27
 
28
28
  @classmethod
29
- def colpali(cls, model: Model = Model.ColPali.V1_3, device: Optional[str] = None):
29
+ def colpali(
30
+ cls,
31
+ model: Model = Model.ColPali.V1_3,
32
+ device: Optional[str] = None,
33
+ text_batch_size: int = 32,
34
+ image_batch_size: int = 8,
35
+ ):
30
36
  """
31
37
  Create EmbedKit instance with ColPali provider.
32
38
 
33
39
  Args:
34
40
  model: ColPali model enum
35
41
  device: Device to run on ('cuda', 'mps', 'cpu', or None for auto-detect)
42
+ text_batch_size: Batch size for text embedding generation
43
+ image_batch_size: Batch size for image embedding generation
36
44
  """
37
45
  if model == Model.ColPali.V1_3:
38
46
  model_name = "vidore/colpali-v1.3"
39
47
  else:
40
48
  raise ValueError(f"Unsupported model: {model}")
41
49
 
42
-
43
- provider = ColPaliProvider(model_name=model_name, device=device)
50
+ provider = ColPaliProvider(
51
+ model_name=model_name,
52
+ device=device,
53
+ text_batch_size=text_batch_size,
54
+ image_batch_size=image_batch_size,
55
+ )
44
56
  return cls(provider)
45
57
 
46
58
  @classmethod
@@ -48,6 +60,8 @@ class EmbedKit:
48
60
  cls,
49
61
  api_key: str,
50
62
  model: Model = Model.Cohere.EMBED_V4_0,
63
+ text_batch_size: int = 32,
64
+ image_batch_size: int = 8,
51
65
  text_input_type: CohereInputType = CohereInputType.SEARCH_DOCUMENT,
52
66
  ):
53
67
  """
@@ -56,6 +70,8 @@ class EmbedKit:
56
70
  Args:
57
71
  api_key: Cohere API key
58
72
  model: Cohere model enum
73
+ text_batch_size: Batch size for text embedding generation
74
+ image_batch_size: Batch size for image embedding generation
59
75
  text_input_type: Type of input for text embedding (search_document or search_query)
60
76
  """
61
77
  if not api_key:
@@ -67,7 +83,10 @@ class EmbedKit:
67
83
  raise ValueError(f"Unsupported model: {model}")
68
84
 
69
85
  provider = CohereProvider(
70
- api_key=api_key, model_name=model_name, text_input_type=text_input_type
86
+ api_key=api_key, model_name=model_name,
87
+ text_batch_size=48,
88
+ image_batch_size=8,
89
+ text_input_type=text_input_type
71
90
  )
72
91
  return cls(provider)
73
92
 
@@ -2,7 +2,7 @@
2
2
  """Base classes for EmbedKit."""
3
3
 
4
4
  from abc import ABC, abstractmethod
5
- from typing import Union, List
5
+ from typing import Union, List, Optional
6
6
  from pathlib import Path
7
7
  import numpy as np
8
8
  from dataclasses import dataclass
@@ -14,6 +14,7 @@ class EmbeddingResult:
14
14
  model_name: str
15
15
  model_provider: str
16
16
  input_type: str
17
+ source_images_b64: Optional[List[str]] = None
17
18
 
18
19
  @property
19
20
  def shape(self) -> tuple:
@@ -36,9 +37,7 @@ class EmbeddingProvider(ABC):
36
37
  pass
37
38
 
38
39
  @abstractmethod
39
- def embed_pdf(
40
- self, pdf: Union[Path, str]
41
- ) -> EmbeddingResult:
40
+ def embed_pdf(self, pdf: Union[Path, str]) -> EmbeddingResult:
42
41
  """Generate image embeddings from PDFs using the configured provider. Takes a single PDF file."""
43
42
  pass
44
43
 
@@ -6,7 +6,7 @@ from pathlib import Path
6
6
  import numpy as np
7
7
  from enum import Enum
8
8
 
9
- from ..utils import pdf_to_images
9
+ from ..utils import pdf_to_images, image_to_base64
10
10
  from ..base import EmbeddingProvider, EmbeddingError, EmbeddingResult
11
11
 
12
12
 
@@ -24,10 +24,14 @@ class CohereProvider(EmbeddingProvider):
24
24
  self,
25
25
  api_key: str,
26
26
  model_name: str,
27
+ text_batch_size: int,
28
+ image_batch_size: int,
27
29
  text_input_type: CohereInputType = CohereInputType.SEARCH_DOCUMENT,
28
30
  ):
29
31
  self.api_key = api_key
30
32
  self.model_name = model_name
33
+ self.text_batch_size = text_batch_size
34
+ self.image_batch_size = image_batch_size
31
35
  self.input_type = text_input_type
32
36
  self._client = None
33
37
  self.provider_name = "Cohere"
@@ -55,15 +59,21 @@ class CohereProvider(EmbeddingProvider):
55
59
  texts = [texts]
56
60
 
57
61
  try:
58
- response = client.embed(
59
- texts=texts,
60
- model=self.model_name,
61
- input_type=self.input_type.value,
62
- embedding_types=["float"],
63
- )
62
+ all_embeddings = []
63
+
64
+ # Process texts in batches
65
+ for i in range(0, len(texts), self.text_batch_size):
66
+ batch_texts = texts[i : i + self.text_batch_size]
67
+ response = client.embed(
68
+ texts=batch_texts,
69
+ model=self.model_name,
70
+ input_type=self.input_type.value,
71
+ embedding_types=["float"],
72
+ )
73
+ all_embeddings.extend(response.embeddings.float_)
64
74
 
65
75
  return EmbeddingResult(
66
- embeddings=np.array(response.embeddings.float_),
76
+ embeddings=np.array(all_embeddings),
67
77
  model_name=self.model_name,
68
78
  model_provider=self.provider_name,
69
79
  input_type=self.input_type.value,
@@ -81,60 +91,45 @@ class CohereProvider(EmbeddingProvider):
81
91
  input_type = "image"
82
92
 
83
93
  if isinstance(images, (str, Path)):
84
- images = [images]
94
+ images = [Path(images)]
95
+ else:
96
+ images = [Path(img) for img in images]
85
97
 
86
98
  try:
87
- import base64
88
-
89
- b64_images = []
90
- for image in images:
91
- if isinstance(image, (Path, str)):
92
- try:
93
- base64_only = base64.b64encode(Path(image).read_bytes()).decode(
94
- "utf-8"
95
- )
96
- except Exception as e:
97
- raise EmbeddingError(
98
- f"Failed to read image {image}: {e}"
99
- ) from e
100
-
101
- if isinstance(image, Path):
102
- image = str(image)
103
-
104
- if image.lower().endswith(".png"):
105
- content_type = "image/png"
106
- elif image.lower().endswith((".jpg", ".jpeg")):
107
- content_type = "image/jpeg"
108
- elif image.lower().endswith(".gif"):
109
- content_type = "image/gif"
110
- else:
111
- raise EmbeddingError(
112
- f"Unsupported image format for {image}; expected .png, .jpg, .jpeg, or .gif"
113
- )
114
- base64_image = f"data:{content_type};base64,{base64_only}"
115
- else:
116
- raise EmbeddingError(f"Unsupported image type: {type(image)}")
117
-
118
- b64_images.append(base64_image)
119
-
120
- response = client.embed(
121
- model=self.model_name,
122
- input_type="image",
123
- images=b64_images,
124
- embedding_types=["float"],
125
- )
99
+ all_embeddings = []
100
+ all_b64_images = []
101
+
102
+ # Process images in batches
103
+ for i in range(0, len(images), self.image_batch_size):
104
+ batch_images = images[i : i + self.image_batch_size]
105
+ b64_images = []
106
+
107
+ for image in batch_images:
108
+ if not image.exists():
109
+ raise EmbeddingError(f"Image not found: {image}")
110
+ b64_images.append(image_to_base64(image))
111
+
112
+ response = client.embed(
113
+ model=self.model_name,
114
+ input_type="image",
115
+ images=b64_images,
116
+ embedding_types=["float"],
117
+ )
118
+
119
+ all_embeddings.extend(response.embeddings.float_)
120
+ all_b64_images.extend(b64_images)
126
121
 
127
122
  return EmbeddingResult(
128
- embeddings=np.array(response.embeddings.float_),
123
+ embeddings=np.array(all_embeddings),
129
124
  model_name=self.model_name,
130
125
  model_provider=self.provider_name,
131
126
  input_type=input_type,
127
+ source_images_b64=all_b64_images,
132
128
  )
133
129
 
134
130
  except Exception as e:
135
131
  raise EmbeddingError(f"Failed to embed image with Cohere: {e}") from e
136
132
 
137
-
138
133
  def embed_pdf(self, pdf_path: Path) -> EmbeddingResult:
139
134
  """Generate embeddings for a PDF file using Cohere API."""
140
135
  image_paths = pdf_to_images(pdf_path)
@@ -8,7 +8,7 @@ import numpy as np
8
8
  import torch
9
9
  from PIL import Image
10
10
 
11
- from ..utils import pdf_to_images
11
+ from ..utils import pdf_to_images, image_to_base64
12
12
  from ..base import EmbeddingProvider, EmbeddingError, EmbeddingResult
13
13
 
14
14
  logger = logging.getLogger(__name__)
@@ -17,9 +17,17 @@ logger = logging.getLogger(__name__)
17
17
  class ColPaliProvider(EmbeddingProvider):
18
18
  """ColPali embedding provider for document understanding."""
19
19
 
20
- def __init__(self, model_name: str, device: Optional[str] = None):
20
+ def __init__(
21
+ self,
22
+ model_name: str,
23
+ text_batch_size: int,
24
+ image_batch_size: int,
25
+ device: Optional[str] = None,
26
+ ):
21
27
  self.model_name = model_name
22
28
  self.provider_name = "ColPali"
29
+ self.text_batch_size = text_batch_size
30
+ self.image_batch_size = image_batch_size
23
31
 
24
32
  # Auto-detect device
25
33
  if device is None:
@@ -64,13 +72,22 @@ class ColPaliProvider(EmbeddingProvider):
64
72
  texts = [texts]
65
73
 
66
74
  try:
67
- processed = self._processor.process_queries(texts).to(self.device)
75
+ # Process texts in batches
76
+ all_embeddings = []
68
77
 
69
- with torch.no_grad():
70
- embeddings = self._model(**processed)
78
+ for i in range(0, len(texts), self.text_batch_size):
79
+ batch_texts = texts[i : i + self.text_batch_size]
80
+ processed = self._processor.process_queries(batch_texts).to(self.device)
81
+
82
+ with torch.no_grad():
83
+ batch_embeddings = self._model(**processed)
84
+ all_embeddings.append(batch_embeddings.cpu().float().numpy())
85
+
86
+ # Concatenate all batch embeddings
87
+ final_embeddings = np.concatenate(all_embeddings, axis=0)
71
88
 
72
89
  return EmbeddingResult(
73
- embeddings=embeddings.cpu().float().numpy(),
90
+ embeddings=final_embeddings,
74
91
  model_name=self.model_name,
75
92
  model_provider=self.provider_name,
76
93
  input_type="text",
@@ -91,30 +108,44 @@ class ColPaliProvider(EmbeddingProvider):
91
108
  images = [Path(img) for img in images]
92
109
 
93
110
  try:
94
- pil_images = []
95
- for img_path in images:
96
- if not img_path.exists():
97
- raise EmbeddingError(f"Image not found: {img_path}")
111
+ # Process images in batches
112
+ all_embeddings = []
113
+ all_b64_images = []
114
+
115
+ for i in range(0, len(images), self.image_batch_size):
116
+ batch_images = images[i : i + self.image_batch_size]
117
+ pil_images = []
118
+ b64_images = []
98
119
 
99
- with Image.open(img_path) as img:
100
- pil_images.append(img.convert("RGB"))
120
+ for img_path in batch_images:
121
+ if not img_path.exists():
122
+ raise EmbeddingError(f"Image not found: {img_path}")
101
123
 
102
- processed = self._processor.process_images(pil_images).to(self.device)
124
+ with Image.open(img_path) as img:
125
+ pil_images.append(img.convert("RGB"))
126
+ b64_images.append(image_to_base64(img_path))
103
127
 
104
- with torch.no_grad():
105
- embeddings = self._model(**processed)
128
+ processed = self._processor.process_images(pil_images).to(self.device)
129
+
130
+ with torch.no_grad():
131
+ batch_embeddings = self._model(**processed)
132
+ all_embeddings.append(batch_embeddings.cpu().float().numpy())
133
+ all_b64_images.extend(b64_images)
134
+
135
+ # Concatenate all batch embeddings
136
+ final_embeddings = np.concatenate(all_embeddings, axis=0)
106
137
 
107
138
  return EmbeddingResult(
108
- embeddings=embeddings.cpu().float().numpy(),
139
+ embeddings=final_embeddings,
109
140
  model_name=self.model_name,
110
141
  model_provider=self.provider_name,
111
142
  input_type="image",
143
+ source_images_b64=all_b64_images,
112
144
  )
113
145
 
114
146
  except Exception as e:
115
147
  raise EmbeddingError(f"Failed to embed images: {e}") from e
116
148
 
117
-
118
149
  def embed_pdf(self, pdf_path: Path) -> EmbeddingResult:
119
150
  """Generate embeddings for a PDF file using ColPali API."""
120
151
  images = pdf_to_images(pdf_path)
@@ -0,0 +1,48 @@
1
+ from pdf2image import convert_from_path
2
+ from pathlib import Path
3
+ from .config import get_temp_dir
4
+ from typing import Union
5
+
6
+
7
def pdf_to_images(pdf_path: Union[str, Path]) -> list[Path]:
    """Convert a PDF file to PNG images on disk and return their paths.

    Args:
        pdf_path: Location of the PDF. Both ``str`` and ``Path`` values
            are accepted; the value is normalised to ``Path`` internally.

    Returns:
        One path per image returned by ``convert_from_path``, in order.
        Files are named ``<pdf stem>_<index>.png`` and live in the
        ``images`` subdirectory of the directory from ``get_temp_dir()``.
    """
    # Normalise first: ``.stem`` below needs a Path, and callers may pass str.
    pdf_path = Path(pdf_path)

    img_temp_dir = get_temp_dir() / "images"
    img_temp_dir.mkdir(parents=True, exist_ok=True)
    images = convert_from_path(pdf_path=str(pdf_path), output_folder=str(img_temp_dir))

    image_paths = []
    for i, image in enumerate(images):
        output_path = img_temp_dir / f"{pdf_path.stem}_{i}.png"
        # Remove a same-named file left by an earlier conversion before writing.
        if output_path.exists():
            output_path.unlink()

        image.save(output_path)
        image_paths.append(output_path)
    return image_paths
23
+
24
+
25
def image_to_base64(image_path: Union[str, Path]) -> str:
    """Read an image file and return it as a base64 ``data:`` URI string.

    Args:
        image_path: Path of the image, as ``str`` or ``Path``. The file
            name must end in ``.png``, ``.jpg``, ``.jpeg`` or ``.gif``
            (case-insensitive); the content type is chosen from that
            ending, not from the file contents.

    Returns:
        A string of the form ``data:<content type>;base64,<payload>``.

    Raises:
        ValueError: If the file cannot be read, or if its name does not
            end in one of the supported extensions.
    """
    import base64

    # Broad catch is deliberate: every read failure is reported as ValueError.
    try:
        base64_only = base64.b64encode(Path(image_path).read_bytes()).decode("utf-8")
    except Exception as e:
        raise ValueError(f"Failed to read image {image_path}: {e}") from e

    # Always derive the string form; doing so only for Path inputs left the
    # name unbound (UnboundLocalError) whenever a plain str was passed.
    image_path_str = str(image_path).lower()

    if image_path_str.endswith(".png"):
        content_type = "image/png"
    elif image_path_str.endswith((".jpg", ".jpeg")):
        content_type = "image/jpeg"
    elif image_path_str.endswith(".gif"):
        content_type = "image/gif"
    else:
        raise ValueError(
            f"Unsupported image format for {image_path}; expected .png, .jpg, .jpeg, or .gif"
        )

    return f"data:{content_type};base64,{base64_only}"
embedkit-0.1.0/main.py DELETED
@@ -1,78 +0,0 @@
1
- # ./main.py
2
- from embedkit import EmbedKit
3
- from embedkit.models import Model
4
- from embedkit.providers.cohere import CohereInputType
5
- from pathlib import Path
6
- import os
7
-
8
-
9
- def get_online_image(url: str) -> Path:
10
- """Download an image from a URL and return its local path."""
11
- import requests
12
- from tempfile import NamedTemporaryFile
13
-
14
- # Add User-Agent header to comply with Wikipedia's policy
15
- headers = {"User-Agent": "EmbedKit-Example/1.0"}
16
-
17
- response = requests.get(url, headers=headers)
18
- response.raise_for_status()
19
-
20
- temp_file = NamedTemporaryFile(delete=False, suffix=".png")
21
- temp_file.write(response.content)
22
- temp_file.close()
23
-
24
- return Path(temp_file.name)
25
-
26
-
27
- def get_sample_image() -> Path:
28
- """Get a sample image for testing."""
29
- url = "https://upload.wikimedia.org/wikipedia/commons/b/b8/English_Wikipedia_HomePage_2001-12-20.png"
30
- return get_online_image(url)
31
-
32
-
33
- sample_image = get_sample_image()
34
-
35
- sample_pdf = Path("tests/fixtures/2407.01449v6_p1.pdf")
36
-
37
- kit = EmbedKit.colpali(model=Model.ColPali.V1_3)
38
-
39
- embeddings = kit.embed_text("Hello world")
40
- assert embeddings.shape[0] == 1
41
- assert len(embeddings.shape) == 3
42
-
43
- embeddings = kit.embed_image(sample_image)
44
- assert embeddings.shape[0] == 1
45
- assert len(embeddings.shape) == 3
46
-
47
- embeddings = kit.embed_pdf(sample_pdf)
48
- assert embeddings.shape[0] == 1
49
- assert len(embeddings.shape) == 3
50
-
51
-
52
- kit = EmbedKit.cohere(
53
- model=Model.Cohere.EMBED_V4_0,
54
- api_key=os.getenv("COHERE_API_KEY"),
55
- text_input_type=CohereInputType.SEARCH_QUERY,
56
- )
57
-
58
- embeddings = kit.embed_text("Hello world")
59
- assert embeddings.shape[0] == 1
60
- assert len(embeddings.shape) == 2
61
-
62
- kit = EmbedKit.cohere(
63
- model=Model.Cohere.EMBED_V4_0,
64
- api_key=os.getenv("COHERE_API_KEY"),
65
- text_input_type=CohereInputType.SEARCH_DOCUMENT,
66
- )
67
-
68
- embeddings = kit.embed_text("Hello world")
69
- assert embeddings.shape[0] == 1
70
- assert len(embeddings.shape) == 2
71
-
72
- embeddings = kit.embed_image(sample_image)
73
- assert embeddings.shape[0] == 1
74
- assert len(embeddings.shape) == 2
75
-
76
- embeddings = kit.embed_pdf(sample_pdf)
77
- assert embeddings.shape[0] == 1
78
- assert len(embeddings.shape) == 2
@@ -1,21 +0,0 @@
1
- from pdf2image import convert_from_path
2
- from pathlib import Path
3
- from .config import get_temp_dir
4
-
5
-
6
- def pdf_to_images(pdf_path: Path) -> list[Path]:
7
- """Convert a PDF file to a list of images."""
8
- root_temp_dir = get_temp_dir()
9
- img_temp_dir = root_temp_dir / "images"
10
- img_temp_dir.mkdir(parents=True, exist_ok=True)
11
- images = convert_from_path(pdf_path=str(pdf_path), output_folder=str(img_temp_dir))
12
- image_paths = []
13
-
14
- for i, image in enumerate(images):
15
- output_path = img_temp_dir / f"{pdf_path.stem}_{i}.png"
16
- if output_path.exists():
17
- output_path.unlink()
18
-
19
- image.save(output_path)
20
- image_paths.append(output_path)
21
- return image_paths
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes