embedkit 0.1.4__py3-none-any.whl → 0.1.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
embedkit/__init__.py CHANGED
@@ -5,7 +5,6 @@ EmbedKit: A unified toolkit for generating vector embeddings.
5
5
 
6
6
  from typing import Union, List, Optional
7
7
  from pathlib import Path
8
- import numpy as np
9
8
 
10
9
  from .models import Model
11
10
  from .base import EmbeddingError, EmbeddingResponse
@@ -28,7 +27,7 @@ class EmbedKit:
28
27
  @classmethod
29
28
  def colpali(
30
29
  cls,
31
- model: Model = Model.ColPali.V1_3,
30
+ model: Model = Model.ColPali.COLPALI_V1_3,
32
31
  device: Optional[str] = None,
33
32
  text_batch_size: int = 32,
34
33
  image_batch_size: int = 8,
@@ -42,13 +41,13 @@ class EmbedKit:
42
41
  text_batch_size: Batch size for text embedding generation
43
42
  image_batch_size: Batch size for image embedding generation
44
43
  """
45
- if model == Model.ColPali.V1_3:
46
- model_name = "vidore/colpali-v1.3"
47
- else:
48
- raise ValueError(f"Unsupported model: {model}")
44
+ if not isinstance(model, Model.ColPali):
45
+ raise ValueError(
46
+ f"Unsupported model: {model}. Must be a Model.ColPali enum value."
47
+ )
49
48
 
50
49
  provider = ColPaliProvider(
51
- model_name=model_name,
50
+ model=model,
52
51
  device=device,
53
52
  text_batch_size=text_batch_size,
54
53
  image_batch_size=image_batch_size,
@@ -77,16 +76,15 @@ class EmbedKit:
77
76
  if not api_key:
78
77
  raise ValueError("API key is required")
79
78
 
80
- if model == Model.Cohere.EMBED_V4_0:
81
- model_name = "embed-v4.0"
82
- else:
79
+ if not isinstance(model, Model.Cohere):
83
80
  raise ValueError(f"Unsupported model: {model}")
84
81
 
85
82
  provider = CohereProvider(
86
- api_key=api_key, model_name=model_name,
83
+ api_key=api_key,
84
+ model=model,
87
85
  text_batch_size=text_batch_size,
88
86
  image_batch_size=image_batch_size,
89
- text_input_type=text_input_type
87
+ text_input_type=text_input_type,
90
88
  )
91
89
  return cls(provider)
92
90
 
embedkit/base.py CHANGED
@@ -7,11 +7,15 @@ from pathlib import Path
7
7
  import numpy as np
8
8
  from dataclasses import dataclass
9
9
 
10
+ from .models import Model
11
+ from .utils import with_pdf_cleanup
12
+
10
13
 
11
14
  @dataclass
12
15
  class EmbeddingObject:
13
16
  embedding: np.ndarray
14
17
  source_b64: str = None
18
+ source_content_type: str = None # e.g., "image/png", "image/jpeg"
15
19
 
16
20
 
17
21
  @dataclass
@@ -29,6 +33,67 @@ class EmbeddingResponse:
29
33
  class EmbeddingProvider(ABC):
30
34
  """Abstract base class for embedding providers."""
31
35
 
36
+ def __init__(
37
+ self,
38
+ model_name: str,
39
+ text_batch_size: int,
40
+ image_batch_size: int,
41
+ provider_name: str,
42
+ ):
43
+ self.model_name = model_name
44
+ self.provider_name = provider_name
45
+ self.text_batch_size = text_batch_size
46
+ self.image_batch_size = image_batch_size
47
+
48
+ def _normalize_text_input(self, texts: Union[str, List[str]]) -> List[str]:
49
+ """Normalize text input to a list of strings."""
50
+ if isinstance(texts, str):
51
+ return [texts]
52
+ return texts
53
+
54
+ def _normalize_image_input(
55
+ self, images: Union[Path, str, List[Union[Path, str]]]
56
+ ) -> List[Path]:
57
+ """Normalize image input to a list of Path objects."""
58
+ if isinstance(images, (str, Path)):
59
+ return [Path(images)]
60
+ return [Path(img) for img in images]
61
+
62
+ def _create_text_response(
63
+ self, embeddings: List[np.ndarray], input_type: str = "text"
64
+ ) -> EmbeddingResponse:
65
+ """Create a standardized text embedding response."""
66
+ return EmbeddingResponse(
67
+ model_name=self.model_name,
68
+ model_provider=self.provider_name,
69
+ input_type=input_type,
70
+ objects=[EmbeddingObject(embedding=e) for e in embeddings],
71
+ )
72
+
73
+ def _create_image_response(
74
+ self,
75
+ embeddings: List[np.ndarray],
76
+ b64_data: List[str],
77
+ content_types: List[str],
78
+ input_type: str = "image",
79
+ ) -> EmbeddingResponse:
80
+ """Create a standardized image embedding response."""
81
+ return EmbeddingResponse(
82
+ model_name=self.model_name,
83
+ model_provider=self.provider_name,
84
+ input_type=input_type,
85
+ objects=[
86
+ EmbeddingObject(
87
+ embedding=embedding,
88
+ source_b64=b64_data,
89
+ source_content_type=content_type,
90
+ )
91
+ for embedding, b64_data, content_type in zip(
92
+ embeddings, b64_data, content_types
93
+ )
94
+ ],
95
+ )
96
+
32
97
  @abstractmethod
33
98
  def embed_text(self, texts: Union[str, List[str]], **kwargs) -> EmbeddingResponse:
34
99
  """Generate document text embeddings using the configured provider."""
@@ -41,10 +106,14 @@ class EmbeddingProvider(ABC):
41
106
  """Generate image embeddings using the configured provider."""
42
107
  pass
43
108
 
44
- @abstractmethod
45
- def embed_pdf(self, pdf: Union[Path, str]) -> EmbeddingResponse:
46
- """Generate image embeddings from PDFsusing the configured provider. Takes a single PDF file."""
47
- pass
109
+ def embed_pdf(self, pdf_path: Path) -> EmbeddingResponse:
110
+ """Generate embeddings for a PDF file."""
111
+ return self._embed_pdf_impl(pdf_path)
112
+
113
+ @with_pdf_cleanup
114
+ def _embed_pdf_impl(self, pdf_path: List[Path]) -> EmbeddingResponse:
115
+ """Internal implementation of PDF embedding with cleanup handled by decorator."""
116
+ return self.embed_image(pdf_path)
48
117
 
49
118
 
50
119
  class EmbeddingError(Exception):
embedkit/classes.py CHANGED
@@ -13,9 +13,4 @@ from . import EmbeddingResponse, EmbeddingError
13
13
  from .models import Model
14
14
  from .providers.cohere import CohereInputType
15
15
 
16
- __all__ = [
17
- "EmbeddingResponse",
18
- "EmbeddingError",
19
- "Model",
20
- "CohereInputType"
21
- ]
16
+ __all__ = ["EmbeddingResponse", "EmbeddingError", "Model", "CohereInputType"]
embedkit/models.py CHANGED
@@ -6,7 +6,13 @@ from enum import Enum
6
6
 
7
7
  class Model:
8
8
  class ColPali(Enum):
9
- V1_3 = "colpali-v1.3"
9
+ COLPALI_V1_3 = "vidore/colpali-v1.3"
10
+ COLSMOL_500M = "vidore/colSmol-500M"
11
+ COLSMOL_256M = "vidore/colSmol-256M"
10
12
 
11
13
  class Cohere(Enum):
12
14
  EMBED_V4_0 = "embed-v4.0"
15
+ EMBED_ENGLISH_V3_0 = "embed-english-v3.0"
16
+ EMBED_ENGLISH_LIGHT_V3_0 = "embed-english-light-v3.0"
17
+ EMBED_MULTILINGUAL_V3_0 = "embed-multilingual-v3.0"
18
+ EMBED_MULTILINGUAL_LIGHT_V3_0 = "embed-multilingual-light-v3.0"
@@ -5,9 +5,13 @@ from typing import Union, List
5
5
  from pathlib import Path
6
6
  import numpy as np
7
7
  from enum import Enum
8
+ import logging
8
9
 
9
- from ..utils import pdf_to_images, image_to_base64
10
- from ..base import EmbeddingProvider, EmbeddingError, EmbeddingResponse, EmbeddingObject
10
+ from ..models import Model
11
+ from ..utils import image_to_base64
12
+ from ..base import EmbeddingProvider, EmbeddingError, EmbeddingResponse
13
+
14
+ logger = logging.getLogger(__name__)
11
15
 
12
16
 
13
17
  class CohereInputType(Enum):
@@ -23,18 +27,20 @@ class CohereProvider(EmbeddingProvider):
23
27
  def __init__(
24
28
  self,
25
29
  api_key: str,
26
- model_name: str,
30
+ model: Model.Cohere,
27
31
  text_batch_size: int,
28
32
  image_batch_size: int,
29
33
  text_input_type: CohereInputType = CohereInputType.SEARCH_DOCUMENT,
30
34
  ):
35
+ super().__init__(
36
+ model_name=model.value,
37
+ text_batch_size=text_batch_size,
38
+ image_batch_size=image_batch_size,
39
+ provider_name="Cohere",
40
+ )
31
41
  self.api_key = api_key
32
- self.model_name = model_name
33
- self.text_batch_size = text_batch_size
34
- self.image_batch_size = image_batch_size
35
42
  self.input_type = text_input_type
36
43
  self._client = None
37
- self.provider_name = "Cohere"
38
44
 
39
45
  def _get_client(self):
40
46
  """Lazy load the Cohere client."""
@@ -54,9 +60,7 @@ class CohereProvider(EmbeddingProvider):
54
60
  def embed_text(self, texts: Union[str, List[str]], **kwargs) -> EmbeddingResponse:
55
61
  """Generate text embeddings using the Cohere API."""
56
62
  client = self._get_client()
57
-
58
- if isinstance(texts, str):
59
- texts = [texts]
63
+ texts = self._normalize_text_input(texts)
60
64
 
61
65
  try:
62
66
  all_embeddings = []
@@ -72,16 +76,7 @@ class CohereProvider(EmbeddingProvider):
72
76
  )
73
77
  all_embeddings.extend(np.array(response.embeddings.float_))
74
78
 
75
- return EmbeddingResponse(
76
- model_name=self.model_name,
77
- model_provider=self.provider_name,
78
- input_type=self.input_type.value,
79
- objects=[
80
- EmbeddingObject(
81
- embedding=e,
82
- ) for e in all_embeddings
83
- ]
84
- )
79
+ return self._create_text_response(all_embeddings, self.input_type.value)
85
80
 
86
81
  except Exception as e:
87
82
  raise EmbeddingError(f"Failed to embed text with Cohere: {e}") from e
@@ -92,12 +87,9 @@ class CohereProvider(EmbeddingProvider):
92
87
  ) -> EmbeddingResponse:
93
88
  """Generate embeddings for images using Cohere API."""
94
89
  client = self._get_client()
95
- input_type = "image"
96
-
97
- if isinstance(images, (str, Path)):
98
- images = [Path(images)]
99
- else:
100
- images = [Path(img) for img in images]
90
+ images = self._normalize_image_input(images)
91
+ total_images = len(images)
92
+ logger.info(f"Starting to process {total_images} images")
101
93
 
102
94
  try:
103
95
  all_embeddings = []
@@ -106,12 +98,17 @@ class CohereProvider(EmbeddingProvider):
106
98
  # Process images in batches
107
99
  for i in range(0, len(images), self.image_batch_size):
108
100
  batch_images = images[i : i + self.image_batch_size]
101
+ logger.info(f"Processing batch {i//self.image_batch_size + 1} of {(total_images + self.image_batch_size - 1)//self.image_batch_size} ({len(batch_images)} images)")
109
102
  b64_images = []
110
103
 
111
104
  for image in batch_images:
112
105
  if not image.exists():
113
106
  raise EmbeddingError(f"Image not found: {image}")
114
- b64_images.append(image_to_base64(image))
107
+ b64_data, content_type = image_to_base64(image)
108
+ # Construct full data URI for API
109
+ data_uri = f"data:{content_type};base64,{b64_data}"
110
+ b64_images.append(data_uri)
111
+ all_b64_images.append((b64_data, content_type))
115
112
 
116
113
  response = client.embed(
117
114
  model=self.model_name,
@@ -121,24 +118,14 @@ class CohereProvider(EmbeddingProvider):
121
118
  )
122
119
 
123
120
  all_embeddings.extend(np.array(response.embeddings.float_))
124
- all_b64_images.extend(b64_images)
125
-
126
- return EmbeddingResponse(
127
- model_name=self.model_name,
128
- model_provider=self.provider_name,
129
- input_type=input_type,
130
- objects=[
131
- EmbeddingObject(
132
- embedding=all_embeddings[i],
133
- source_b64=all_b64_images[i]
134
- ) for i in range(len(all_embeddings))
135
- ]
121
+
122
+ logger.info(f"Successfully processed all {total_images} images")
123
+ return self._create_image_response(
124
+ all_embeddings,
125
+ [b64 for b64, _ in all_b64_images],
126
+ [content_type for _, content_type in all_b64_images],
136
127
  )
137
128
 
138
129
  except Exception as e:
130
+ logger.error(f"Failed to embed images: {e}")
139
131
  raise EmbeddingError(f"Failed to embed image with Cohere: {e}") from e
140
-
141
- def embed_pdf(self, pdf_path: Path) -> EmbeddingResponse:
142
- """Generate embeddings for a PDF file using Cohere API."""
143
- image_paths = pdf_to_images(pdf_path)
144
- return self.embed_image(image_paths)
@@ -8,8 +8,9 @@ import numpy as np
8
8
  import torch
9
9
  from PIL import Image
10
10
 
11
- from ..utils import pdf_to_images, image_to_base64
12
- from ..base import EmbeddingProvider, EmbeddingError, EmbeddingResponse, EmbeddingObject
11
+ from ..models import Model
12
+ from ..utils import image_to_base64
13
+ from ..base import EmbeddingProvider, EmbeddingError, EmbeddingResponse
13
14
 
14
15
  logger = logging.getLogger(__name__)
15
16
 
@@ -19,15 +20,17 @@ class ColPaliProvider(EmbeddingProvider):
19
20
 
20
21
  def __init__(
21
22
  self,
22
- model_name: str,
23
+ model: Model.ColPali,
23
24
  text_batch_size: int,
24
25
  image_batch_size: int,
25
26
  device: Optional[str] = None,
26
27
  ):
27
- self.model_name = model_name
28
- self.provider_name = "ColPali"
29
- self.text_batch_size = text_batch_size
30
- self.image_batch_size = image_batch_size
28
+ super().__init__(
29
+ model_name=model.value,
30
+ text_batch_size=text_batch_size,
31
+ image_batch_size=image_batch_size,
32
+ provider_name="ColPali",
33
+ )
31
34
 
32
35
  # Auto-detect device
33
36
  if device is None:
@@ -38,24 +41,43 @@ class ColPaliProvider(EmbeddingProvider):
38
41
  else:
39
42
  device = "cpu"
40
43
 
41
- self.device = device
42
- self._model = None
43
- self._processor = None
44
+ self._hf_device = device
45
+ self._hf_model = None
46
+ self._hf_processor = None
44
47
 
45
48
  def _load_model(self):
46
49
  """Lazy load the model."""
47
- if self._model is None:
50
+ if self._hf_model is None:
48
51
  try:
49
- from colpali_engine.models import ColPali, ColPaliProcessor
50
-
51
- self._model = ColPali.from_pretrained(
52
- self.model_name,
53
- torch_dtype=torch.bfloat16,
54
- device_map=self.device,
55
- ).eval()
56
-
57
- self._processor = ColPaliProcessor.from_pretrained(self.model_name)
58
- logger.info(f"Loaded ColPali model on {self.device}")
52
+ if self.model_name in [Model.ColPali.COLPALI_V1_3.value]:
53
+ from colpali_engine.models import ColPali, ColPaliProcessor
54
+
55
+ self._hf_model = ColPali.from_pretrained(
56
+ self.model_name,
57
+ torch_dtype=torch.bfloat16,
58
+ device_map=self._hf_device,
59
+ ).eval()
60
+
61
+ self._hf_processor = ColPaliProcessor.from_pretrained(self.model_name)
62
+
63
+ elif self.model_name in [
64
+ Model.ColPali.COLSMOL_500M.value,
65
+ Model.ColPali.COLSMOL_256M.value,
66
+ ]:
67
+ from colpali_engine.models import ColIdefics3, ColIdefics3Processor
68
+
69
+ self._hf_model = ColIdefics3.from_pretrained(
70
+ self.model_name,
71
+ torch_dtype=torch.bfloat16,
72
+ device_map=self._hf_device,
73
+ ).eval()
74
+ self._hf_processor = ColIdefics3Processor.from_pretrained(
75
+ self.model_name
76
+ )
77
+ else:
78
+ raise ValueError(f"Unable to load model for: {self.model_name}.")
79
+
80
+ logger.info(f"Loaded {self.model_name} on {self._hf_device}")
59
81
 
60
82
  except ImportError as e:
61
83
  raise EmbeddingError(
@@ -64,38 +86,26 @@ class ColPaliProvider(EmbeddingProvider):
64
86
  except Exception as e:
65
87
  raise EmbeddingError(f"Failed to load model: {e}") from e
66
88
 
67
- def embed_text(self, texts: Union[str, List[str]]) -> EmbeddingResponse:
89
+ def embed_text(self, texts: Union[str, List[str]], **kwargs) -> EmbeddingResponse:
68
90
  """Generate embeddings for text inputs."""
69
91
  self._load_model()
70
-
71
- if isinstance(texts, str):
72
- texts = [texts]
92
+ texts = self._normalize_text_input(texts)
73
93
 
74
94
  try:
75
95
  # Process texts in batches
76
- all_embeddings = []
96
+ all_embeddings: List[np.ndarray] = []
77
97
 
78
98
  for i in range(0, len(texts), self.text_batch_size):
79
99
  batch_texts = texts[i : i + self.text_batch_size]
80
- processed = self._processor.process_queries(batch_texts).to(self.device)
100
+ processed = self._hf_processor.process_queries(batch_texts).to(self._hf_device)
81
101
 
82
102
  with torch.no_grad():
83
- batch_embeddings = self._model(**processed)
103
+ batch_embeddings = self._hf_model(**processed)
84
104
  all_embeddings.append(batch_embeddings.cpu().float().numpy())
85
105
 
86
106
  # Concatenate all batch embeddings
87
107
  final_embeddings = np.concatenate(all_embeddings, axis=0)
88
-
89
- return EmbeddingResponse(
90
- model_name=self.model_name,
91
- model_provider=self.provider_name,
92
- input_type="text",
93
- objects=[
94
- EmbeddingObject(
95
- embedding=e,
96
- ) for e in final_embeddings
97
- ]
98
- )
108
+ return self._create_text_response(final_embeddings)
99
109
 
100
110
  except Exception as e:
101
111
  raise EmbeddingError(f"Failed to embed text: {e}") from e
@@ -105,21 +115,22 @@ class ColPaliProvider(EmbeddingProvider):
105
115
  ) -> EmbeddingResponse:
106
116
  """Generate embeddings for images."""
107
117
  self._load_model()
108
-
109
- if isinstance(images, (str, Path)):
110
- images = [Path(images)]
111
- else:
112
- images = [Path(img) for img in images]
118
+ images = self._normalize_image_input(images)
119
+ total_images = len(images)
120
+ logger.info(f"Starting to process {total_images} images")
113
121
 
114
122
  try:
115
123
  # Process images in batches
116
- all_embeddings = []
117
- all_b64_images = []
124
+ all_embeddings: List[np.ndarray] = []
125
+ all_b64_data: List[str] = []
126
+ all_content_types: List[str] = []
118
127
 
119
128
  for i in range(0, len(images), self.image_batch_size):
120
129
  batch_images = images[i : i + self.image_batch_size]
130
+ logger.info(f"Processing batch {i//self.image_batch_size + 1} of {(total_images + self.image_batch_size - 1)//self.image_batch_size} ({len(batch_images)} images)")
121
131
  pil_images = []
122
- b64_images = []
132
+ batch_b64_data = []
133
+ batch_content_types = []
123
134
 
124
135
  for img_path in batch_images:
125
136
  if not img_path.exists():
@@ -127,34 +138,25 @@ class ColPaliProvider(EmbeddingProvider):
127
138
 
128
139
  with Image.open(img_path) as img:
129
140
  pil_images.append(img.convert("RGB"))
130
- b64_images.append(image_to_base64(img_path))
141
+ b64, content_type = image_to_base64(img_path)
142
+ batch_b64_data.append(b64)
143
+ batch_content_types.append(content_type)
131
144
 
132
- processed = self._processor.process_images(pil_images).to(self.device)
145
+ processed = self._hf_processor.process_images(pil_images).to(self._hf_device)
133
146
 
134
147
  with torch.no_grad():
135
- batch_embeddings = self._model(**processed)
148
+ batch_embeddings = self._hf_model(**processed)
136
149
  all_embeddings.append(batch_embeddings.cpu().float().numpy())
137
- all_b64_images.extend(b64_images)
150
+ all_b64_data.extend(batch_b64_data)
151
+ all_content_types.extend(batch_content_types)
138
152
 
139
153
  # Concatenate all batch embeddings
140
154
  final_embeddings = np.concatenate(all_embeddings, axis=0)
141
-
142
- return EmbeddingResponse(
143
- model_name=self.model_name,
144
- model_provider=self.provider_name,
145
- input_type="image",
146
- objects=[
147
- EmbeddingObject(
148
- embedding=final_embeddings[i],
149
- source_b64=all_b64_images[i]
150
- ) for i in range(len(final_embeddings))
151
- ]
155
+ logger.info(f"Successfully processed all {total_images} images")
156
+ return self._create_image_response(
157
+ final_embeddings, all_b64_data, all_content_types
152
158
  )
153
159
 
154
160
  except Exception as e:
161
+ logger.error(f"Failed to embed images: {e}")
155
162
  raise EmbeddingError(f"Failed to embed images: {e}") from e
156
-
157
- def embed_pdf(self, pdf_path: Path) -> EmbeddingResponse:
158
- """Generate embeddings for a PDF file using ColPali API."""
159
- images = pdf_to_images(pdf_path)
160
- return self.embed_image(images)
embedkit/utils.py CHANGED
@@ -1,37 +1,84 @@
1
+ import tempfile
2
+ import shutil
3
+ import logging
4
+ from contextlib import contextmanager
1
5
  from pdf2image import convert_from_path
2
6
  from pathlib import Path
3
7
  from .config import get_temp_dir
4
- from typing import Union
8
+ from typing import Union, List, Iterator, Callable, TypeVar, Any
5
9
 
10
+ logger = logging.getLogger(__name__)
6
11
 
7
- def pdf_to_images(pdf_path: Path) -> list[Path]:
8
- """Convert a PDF file to a list of images."""
9
- root_temp_dir = get_temp_dir()
10
- img_temp_dir = root_temp_dir / "images"
11
- img_temp_dir.mkdir(parents=True, exist_ok=True)
12
- images = convert_from_path(pdf_path=str(pdf_path), output_folder=str(img_temp_dir))
13
- image_paths = []
12
+ T = TypeVar("T")
14
13
 
15
- for i, image in enumerate(images):
16
- output_path = img_temp_dir / f"{pdf_path.stem}_{i}.png"
17
- if output_path.exists():
18
- output_path.unlink()
19
14
 
20
- image.save(output_path)
21
- image_paths.append(output_path)
22
- return image_paths
15
+ @contextmanager
16
+ def temporary_directory() -> Iterator[Path]:
17
+ """Create a temporary directory that is automatically cleaned up when done.
23
18
 
19
+ Yields:
20
+ Path: Path to the temporary directory
21
+ """
22
+ temp_dir = Path(tempfile.mkdtemp())
23
+ try:
24
+ yield temp_dir
25
+ finally:
26
+ shutil.rmtree(temp_dir)
27
+
28
+
29
+ def pdf_to_images(pdf_path: Path) -> List[Path]:
30
+ """Convert a PDF file to a list of images.
31
+
32
+ The images are stored in a temporary directory that will be automatically
33
+ cleaned up when the process exits.
34
+
35
+ Args:
36
+ pdf_path: Path to the PDF file
37
+
38
+ Returns:
39
+ List[Path]: List of paths to the generated images
40
+
41
+ Note:
42
+ The temporary files will be automatically cleaned up when the process exits.
43
+ Do not rely on these files persisting after the function returns.
44
+ """
45
+ with temporary_directory() as temp_dir:
46
+ images = convert_from_path(pdf_path=str(pdf_path), output_folder=str(temp_dir))
47
+ image_paths = []
48
+
49
+ for i, image in enumerate(images):
50
+ output_path = temp_dir / f"{pdf_path.stem}_{i}.png"
51
+ image.save(output_path)
52
+ final_path = Path(tempfile.mktemp(suffix=".png"))
53
+ shutil.move(output_path, final_path)
54
+ image_paths.append(final_path)
55
+
56
+ return image_paths
57
+
58
+
59
+ def image_to_base64(image_path: Union[str, Path]) -> tuple[str, str]:
60
+ """Convert an image to base64 and return the base64 data and content type.
24
61
 
25
- def image_to_base64(image_path: Union[str, Path]):
62
+ Args:
63
+ image_path: Path to the image file
64
+
65
+ Returns:
66
+ tuple[str, str]: (base64_data, content_type)
67
+
68
+ Raises:
69
+ ValueError: If the image cannot be read or has an unsupported format
70
+ """
26
71
  import base64
27
72
 
28
73
  try:
29
- base64_only = base64.b64encode(Path(image_path).read_bytes()).decode("utf-8")
74
+ base64_data = base64.b64encode(Path(image_path).read_bytes()).decode("utf-8")
30
75
  except Exception as e:
31
76
  raise ValueError(f"Failed to read image {image_path}: {e}") from e
32
77
 
33
78
  if isinstance(image_path, Path):
34
79
  image_path_str = str(image_path)
80
+ else:
81
+ image_path_str = image_path
35
82
 
36
83
  if image_path_str.lower().endswith(".png"):
37
84
  content_type = "image/png"
@@ -43,6 +90,53 @@ def image_to_base64(image_path: Union[str, Path]):
43
90
  raise ValueError(
44
91
  f"Unsupported image format for {image_path}; expected .png, .jpg, .jpeg, or .gif"
45
92
  )
46
- base64_image = f"data:{content_type};base64,{base64_only}"
47
93
 
48
- return base64_image
94
+ return base64_data, content_type
95
+
96
+
97
+ def with_pdf_cleanup(embed_func: Callable[..., T]) -> Callable[..., T]:
98
+ """Decorator to handle PDF to image conversion with automatic cleanup.
99
+
100
+ This decorator handles the common pattern of:
101
+ 1. Converting PDF to images
102
+ 2. Passing images to an embedding function
103
+ 3. Cleaning up temporary files
104
+
105
+ Args:
106
+ embed_func: Function that takes a list of image paths and returns embeddings
107
+
108
+ Returns:
109
+ Callable that takes a PDF path and returns embeddings
110
+ """
111
+
112
+ def wrapper(*args, **kwargs) -> T:
113
+ # First argument is self for instance methods
114
+ pdf_path = args[-1] if args else kwargs.get("pdf_path")
115
+ if not pdf_path:
116
+ raise ValueError(
117
+ "PDF path must be provided as the last positional argument or as 'pdf_path' keyword argument"
118
+ )
119
+
120
+ images = [] # Initialize images as empty list
121
+ try:
122
+ images = pdf_to_images(pdf_path)
123
+ # Call the original function with the images instead of pdf_path
124
+ if args:
125
+ # For instance methods, replace the last argument (pdf_path) with images
126
+ args = list(args)
127
+ args[-1] = images
128
+ else:
129
+ kwargs["pdf_path"] = images
130
+ return embed_func(*args, **kwargs)
131
+ finally:
132
+ # Clean up temporary files created by pdf_to_images
133
+ for img_path in images:
134
+ try:
135
+ if img_path.exists() and str(img_path).startswith(
136
+ tempfile.gettempdir()
137
+ ):
138
+ img_path.unlink()
139
+ except Exception as e:
140
+ logger.warning(f"Failed to clean up temporary file {img_path}: {e}")
141
+
142
+ return wrapper
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: embedkit
3
- Version: 0.1.4
3
+ Version: 0.1.6
4
4
  Summary: A simple toolkit for generating vector embeddings across multiple providers and models
5
5
  Author-email: JP Hwang <me@jphwang.com>
6
6
  License: MIT
@@ -22,7 +22,7 @@ Requires-Dist: colpali-engine<0.4.0,>=0.3.0
22
22
  Requires-Dist: pdf2image>=1.17.0
23
23
  Requires-Dist: pillow>=11.2.1
24
24
  Requires-Dist: torch<=2.5
25
- Requires-Dist: transformers
25
+ Requires-Dist: transformers>=4.46.2
26
26
  Description-Content-Type: text/markdown
27
27
 
28
28
  # EmbedKit
@@ -45,7 +45,7 @@ from embedkit.classes import Model, CohereInputType
45
45
 
46
46
  # Initialize with ColPali
47
47
  kit = EmbedKit.colpali(
48
- model=Model.ColPali.V1_3,
48
+ model=Model.ColPali.COLPALI_V1_3, # or COLSMOL_256M, COLSMOL_500M
49
49
  text_batch_size=16, # Optional: process text in batches of 16
50
50
  image_batch_size=8, # Optional: process images in batches of 8
51
51
  )
@@ -54,7 +54,7 @@ kit = EmbedKit.colpali(
54
54
  result = kit.embed_text("Hello world")
55
55
  print(result.model_provider)
56
56
  print(result.input_type)
57
- print(result.objects[0].embedding.shape)
57
+ print(result.objects[0].embedding.shape) # Returns 2D array for ColPali
58
58
  print(result.objects[0].source_b64)
59
59
 
60
60
  # Initialize with Cohere
@@ -70,7 +70,7 @@ kit = EmbedKit.cohere(
70
70
  result = kit.embed_text("Hello world")
71
71
  print(result.model_provider)
72
72
  print(result.input_type)
73
- print(result.objects[0].embedding.shape)
73
+ print(result.objects[0].embedding.shape) # Returns 1D array for Cohere
74
74
  print(result.objects[0].source_b64)
75
75
  ```
76
76
 
@@ -85,8 +85,8 @@ result = kit.embed_image(image_path)
85
85
 
86
86
  print(result.model_provider)
87
87
  print(result.input_type)
88
- print(result.objects[0].embedding.shape)
89
- print(result.objects[0].source_b64)
88
+ print(result.objects[0].embedding.shape) # 2D for ColPali, 1D for Cohere
89
+ print(result.objects[0].source_b64) # Base64 encoded image
90
90
  ```
91
91
 
92
92
  ### PDF Embeddings
@@ -100,8 +100,8 @@ result = kit.embed_pdf(pdf_path)
100
100
 
101
101
  print(result.model_provider)
102
102
  print(result.input_type)
103
- print(result.objects[0].embedding.shape)
104
- print(result.objects[0].source_b64)
103
+ print(result.objects[0].embedding.shape) # 2D for ColPali, 1D for Cohere
104
+ print(result.objects[0].source_b64) # Base64 encoded PDF page
105
105
  ```
106
106
 
107
107
  ## Response Format
@@ -116,17 +116,23 @@ class EmbeddingResponse:
116
116
  objects: List[EmbeddingObject]
117
117
 
118
118
  class EmbeddingObject:
119
- embedding: np.ndarray
120
- source_b64: Optional[str]
119
+ embedding: np.ndarray # 1D array for Cohere, 2D array for ColPali
120
+ source_b64: Optional[str] # Base64 encoded source for images and PDFs
121
121
  ```
122
122
 
123
123
  ## Supported Models
124
124
 
125
125
  ### ColPali
126
- - `Model.ColPali.V1_3`
126
+ - `Model.ColPali.COLPALI_V1_3`
127
+ - `Model.ColPali.COLSMOL_256M`
128
+ - `Model.ColPali.COLSMOL_500M`
127
129
 
128
130
  ### Cohere
129
131
  - `Model.Cohere.EMBED_V4_0`
132
+ - `Model.Cohere.EMBED_ENGLISH_V3_0`
133
+ - `Model.Cohere.EMBED_ENGLISH_LIGHT_V3_0`
134
+ - `Model.Cohere.EMBED_MULTILINGUAL_V3_0`
135
+ - `Model.Cohere.EMBED_MULTILINGUAL_LIGHT_V3_0`
130
136
 
131
137
  ## Requirements
132
138
 
@@ -0,0 +1,13 @@
1
+ embedkit/__init__.py,sha256=ahMyC4SjYLr3QpM39obwDKpJO_HQ4_kqvG4d7jD5XtA,4503
2
+ embedkit/base.py,sha256=FfyzCqG7azs4yJj4RU6QKQRy9_wfd3KUFidThplmEo0,3739
3
+ embedkit/classes.py,sha256=LI5egTsfcBglUqIxBrDV_ymA1eKcxOUHJeMAhKVtU_Y,582
4
+ embedkit/config.py,sha256=EVGODSKxQAr46bU8dyORFunsfRuj6dnvtSqa4MxUZCo,138
5
+ embedkit/models.py,sha256=fndxt6QoV7nuyeXijs39K3QLs1l8ATw2PX2X-64oKt8,575
6
+ embedkit/utils.py,sha256=xqIjwkWDgDr1mCtbb5UbZo_Ci50griubxYJq3pSBv34,4560
7
+ embedkit/providers/__init__.py,sha256=HaS-HNQabvhn9xLNZCq3VUqPCb7rGG4pvgvpKP4AXcw,201
8
+ embedkit/providers/cohere.py,sha256=B5hgOIY_rexgc0s_DWjZ-yuuOXaYLQV-jmrY-gdBhz8,4726
9
+ embedkit/providers/colpali.py,sha256=8LsV-kkBmNYzCGXWeZy-KLEyMOxTvzOAS4gnzN5q_f4,6260
10
+ embedkit-0.1.6.dist-info/METADATA,sha256=tf0SLNvH5OOYFudiMmgEALYE6fnV716WJO0pPZIQJPw,3871
11
+ embedkit-0.1.6.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
12
+ embedkit-0.1.6.dist-info/licenses/LICENSE,sha256=-g2Rad7b3rb2oVwOTwfMOIpscHT1zuaJoguamLRCBJs,1072
13
+ embedkit-0.1.6.dist-info/RECORD,,
@@ -1,13 +0,0 @@
1
- embedkit/__init__.py,sha256=O7C_e30VMrLliDamyNM_srdyuGz50PvGzBJo9NgePoU,4555
2
- embedkit/base.py,sha256=QkP-Ih7aYjIaxP5Hgp7MVypQlI5mhyOqLWKe0lOpIrQ,1344
3
- embedkit/classes.py,sha256=TlCl58Vo8p0q8SfJGINtBmCXY2yVFLtdlfQqBKrxivw,600
4
- embedkit/config.py,sha256=EVGODSKxQAr46bU8dyORFunsfRuj6dnvtSqa4MxUZCo,138
5
- embedkit/models.py,sha256=EBIYkyZeIhGaOPL-9bslHHdLaZ7qzOYLd0qxVZ7VX7w,226
6
- embedkit/utils.py,sha256=91BzzvbYSrUsWeW3CTAw3yK-M3S5FgQXov16gxffkUo,1572
7
- embedkit/providers/__init__.py,sha256=HaS-HNQabvhn9xLNZCq3VUqPCb7rGG4pvgvpKP4AXcw,201
8
- embedkit/providers/cohere.py,sha256=5-ux5UzpRqlRPm2PjgmRfcxH4NhRRjXa3cOZGWp0jDc,4880
9
- embedkit/providers/colpali.py,sha256=PUjTESyF0r__JdgfT22kK1DGrf7XwYC2EuyBelspTsc,5500
10
- embedkit-0.1.4.dist-info/METADATA,sha256=KPCFTWyeh_8unP1zBmVCUmA0c-yfXwxeMdWf9zO4CjA,3316
11
- embedkit-0.1.4.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
12
- embedkit-0.1.4.dist-info/licenses/LICENSE,sha256=-g2Rad7b3rb2oVwOTwfMOIpscHT1zuaJoguamLRCBJs,1072
13
- embedkit-0.1.4.dist-info/RECORD,,