embedkit 0.1.4__py3-none-any.whl → 0.1.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
embedkit/__init__.py CHANGED
@@ -5,7 +5,6 @@ EmbedKit: A unified toolkit for generating vector embeddings.
5
5
 
6
6
  from typing import Union, List, Optional
7
7
  from pathlib import Path
8
- import numpy as np
9
8
 
10
9
  from .models import Model
11
10
  from .base import EmbeddingError, EmbeddingResponse
@@ -28,7 +27,7 @@ class EmbedKit:
28
27
  @classmethod
29
28
  def colpali(
30
29
  cls,
31
- model: Model = Model.ColPali.V1_3,
30
+ model: Model = Model.ColPali.COLPALI_V1_3,
32
31
  device: Optional[str] = None,
33
32
  text_batch_size: int = 32,
34
33
  image_batch_size: int = 8,
@@ -42,13 +41,13 @@ class EmbedKit:
42
41
  text_batch_size: Batch size for text embedding generation
43
42
  image_batch_size: Batch size for image embedding generation
44
43
  """
45
- if model == Model.ColPali.V1_3:
46
- model_name = "vidore/colpali-v1.3"
47
- else:
48
- raise ValueError(f"Unsupported model: {model}")
44
+ if not isinstance(model, Model.ColPali):
45
+ raise ValueError(
46
+ f"Unsupported model: {model}. Must be a Model.ColPali enum value."
47
+ )
49
48
 
50
49
  provider = ColPaliProvider(
51
- model_name=model_name,
50
+ model=model,
52
51
  device=device,
53
52
  text_batch_size=text_batch_size,
54
53
  image_batch_size=image_batch_size,
@@ -77,16 +76,15 @@ class EmbedKit:
77
76
  if not api_key:
78
77
  raise ValueError("API key is required")
79
78
 
80
- if model == Model.Cohere.EMBED_V4_0:
81
- model_name = "embed-v4.0"
82
- else:
79
+ if not isinstance(model, Model.Cohere):
83
80
  raise ValueError(f"Unsupported model: {model}")
84
81
 
85
82
  provider = CohereProvider(
86
- api_key=api_key, model_name=model_name,
83
+ api_key=api_key,
84
+ model=model,
87
85
  text_batch_size=text_batch_size,
88
86
  image_batch_size=image_batch_size,
89
- text_input_type=text_input_type
87
+ text_input_type=text_input_type,
90
88
  )
91
89
  return cls(provider)
92
90
 
embedkit/base.py CHANGED
@@ -7,11 +7,15 @@ from pathlib import Path
7
7
  import numpy as np
8
8
  from dataclasses import dataclass
9
9
 
10
+ from .models import Model
11
+ from .utils import with_pdf_cleanup
12
+
10
13
 
11
14
  @dataclass
12
15
  class EmbeddingObject:
13
16
  embedding: np.ndarray
14
17
  source_b64: str = None
18
+ source_content_type: str = None # e.g., "image/png", "image/jpeg"
15
19
 
16
20
 
17
21
  @dataclass
@@ -29,6 +33,67 @@ class EmbeddingResponse:
29
33
  class EmbeddingProvider(ABC):
30
34
  """Abstract base class for embedding providers."""
31
35
 
36
+ def __init__(
37
+ self,
38
+ model_name: str,
39
+ text_batch_size: int,
40
+ image_batch_size: int,
41
+ provider_name: str,
42
+ ):
43
+ self.model_name = model_name
44
+ self.provider_name = provider_name
45
+ self.text_batch_size = text_batch_size
46
+ self.image_batch_size = image_batch_size
47
+
48
+ def _normalize_text_input(self, texts: Union[str, List[str]]) -> List[str]:
49
+ """Normalize text input to a list of strings."""
50
+ if isinstance(texts, str):
51
+ return [texts]
52
+ return texts
53
+
54
+ def _normalize_image_input(
55
+ self, images: Union[Path, str, List[Union[Path, str]]]
56
+ ) -> List[Path]:
57
+ """Normalize image input to a list of Path objects."""
58
+ if isinstance(images, (str, Path)):
59
+ return [Path(images)]
60
+ return [Path(img) for img in images]
61
+
62
+ def _create_text_response(
63
+ self, embeddings: List[np.ndarray], input_type: str = "text"
64
+ ) -> EmbeddingResponse:
65
+ """Create a standardized text embedding response."""
66
+ return EmbeddingResponse(
67
+ model_name=self.model_name,
68
+ model_provider=self.provider_name,
69
+ input_type=input_type,
70
+ objects=[EmbeddingObject(embedding=e) for e in embeddings],
71
+ )
72
+
73
+ def _create_image_response(
74
+ self,
75
+ embeddings: List[np.ndarray],
76
+ b64_data: List[str],
77
+ content_types: List[str],
78
+ input_type: str = "image",
79
+ ) -> EmbeddingResponse:
80
+ """Create a standardized image embedding response."""
81
+ return EmbeddingResponse(
82
+ model_name=self.model_name,
83
+ model_provider=self.provider_name,
84
+ input_type=input_type,
85
+ objects=[
86
+ EmbeddingObject(
87
+ embedding=embedding,
88
+ source_b64=b64_data,
89
+ source_content_type=content_type,
90
+ )
91
+ for embedding, b64_data, content_type in zip(
92
+ embeddings, b64_data, content_types
93
+ )
94
+ ],
95
+ )
96
+
32
97
  @abstractmethod
33
98
  def embed_text(self, texts: Union[str, List[str]], **kwargs) -> EmbeddingResponse:
34
99
  """Generate document text embeddings using the configured provider."""
@@ -41,10 +106,14 @@ class EmbeddingProvider(ABC):
41
106
  """Generate image embeddings using the configured provider."""
42
107
  pass
43
108
 
44
- @abstractmethod
45
- def embed_pdf(self, pdf: Union[Path, str]) -> EmbeddingResponse:
46
- """Generate image embeddings from PDFsusing the configured provider. Takes a single PDF file."""
47
- pass
109
+ def embed_pdf(self, pdf_path: Path) -> EmbeddingResponse:
110
+ """Generate embeddings for a PDF file."""
111
+ return self._embed_pdf_impl(pdf_path)
112
+
113
+ @with_pdf_cleanup
114
+ def _embed_pdf_impl(self, pdf_path: List[Path]) -> EmbeddingResponse:
115
+ """Internal implementation of PDF embedding with cleanup handled by decorator."""
116
+ return self.embed_image(pdf_path)
48
117
 
49
118
 
50
119
  class EmbeddingError(Exception):
embedkit/classes.py CHANGED
@@ -13,9 +13,4 @@ from . import EmbeddingResponse, EmbeddingError
13
13
  from .models import Model
14
14
  from .providers.cohere import CohereInputType
15
15
 
16
- __all__ = [
17
- "EmbeddingResponse",
18
- "EmbeddingError",
19
- "Model",
20
- "CohereInputType"
21
- ]
16
+ __all__ = ["EmbeddingResponse", "EmbeddingError", "Model", "CohereInputType"]
embedkit/models.py CHANGED
@@ -6,7 +6,13 @@ from enum import Enum
6
6
 
7
7
  class Model:
8
8
  class ColPali(Enum):
9
- V1_3 = "colpali-v1.3"
9
+ COLPALI_V1_3 = "vidore/colpali-v1.3"
10
+ COLSMOL_500M = "vidore/colSmol-500M"
11
+ COLSMOL_256M = "vidore/colSmol-256M"
10
12
 
11
13
  class Cohere(Enum):
12
14
  EMBED_V4_0 = "embed-v4.0"
15
+ EMBED_ENGLISH_V3_0 = "embed-english-v3.0"
16
+ EMBED_ENGLISH_LIGHT_V3_0 = "embed-english-light-v3.0"
17
+ EMBED_MULTILINGUAL_V3_0 = "embed-multilingual-v3.0"
18
+ EMBED_MULTILINGUAL_LIGHT_V3_0 = "embed-multilingual-light-v3.0"
@@ -6,7 +6,8 @@ from pathlib import Path
6
6
  import numpy as np
7
7
  from enum import Enum
8
8
 
9
- from ..utils import pdf_to_images, image_to_base64
9
+ from ..models import Model
10
+ from ..utils import image_to_base64
10
11
  from ..base import EmbeddingProvider, EmbeddingError, EmbeddingResponse, EmbeddingObject
11
12
 
12
13
 
@@ -23,18 +24,20 @@ class CohereProvider(EmbeddingProvider):
23
24
  def __init__(
24
25
  self,
25
26
  api_key: str,
26
- model_name: str,
27
+ model: Model.Cohere,
27
28
  text_batch_size: int,
28
29
  image_batch_size: int,
29
30
  text_input_type: CohereInputType = CohereInputType.SEARCH_DOCUMENT,
30
31
  ):
32
+ super().__init__(
33
+ model_name=model.value,
34
+ text_batch_size=text_batch_size,
35
+ image_batch_size=image_batch_size,
36
+ provider_name="Cohere",
37
+ )
31
38
  self.api_key = api_key
32
- self.model_name = model_name
33
- self.text_batch_size = text_batch_size
34
- self.image_batch_size = image_batch_size
35
39
  self.input_type = text_input_type
36
40
  self._client = None
37
- self.provider_name = "Cohere"
38
41
 
39
42
  def _get_client(self):
40
43
  """Lazy load the Cohere client."""
@@ -54,9 +57,7 @@ class CohereProvider(EmbeddingProvider):
54
57
  def embed_text(self, texts: Union[str, List[str]], **kwargs) -> EmbeddingResponse:
55
58
  """Generate text embeddings using the Cohere API."""
56
59
  client = self._get_client()
57
-
58
- if isinstance(texts, str):
59
- texts = [texts]
60
+ texts = self._normalize_text_input(texts)
60
61
 
61
62
  try:
62
63
  all_embeddings = []
@@ -72,16 +73,7 @@ class CohereProvider(EmbeddingProvider):
72
73
  )
73
74
  all_embeddings.extend(np.array(response.embeddings.float_))
74
75
 
75
- return EmbeddingResponse(
76
- model_name=self.model_name,
77
- model_provider=self.provider_name,
78
- input_type=self.input_type.value,
79
- objects=[
80
- EmbeddingObject(
81
- embedding=e,
82
- ) for e in all_embeddings
83
- ]
84
- )
76
+ return self._create_text_response(all_embeddings, self.input_type.value)
85
77
 
86
78
  except Exception as e:
87
79
  raise EmbeddingError(f"Failed to embed text with Cohere: {e}") from e
@@ -92,12 +84,7 @@ class CohereProvider(EmbeddingProvider):
92
84
  ) -> EmbeddingResponse:
93
85
  """Generate embeddings for images using Cohere API."""
94
86
  client = self._get_client()
95
- input_type = "image"
96
-
97
- if isinstance(images, (str, Path)):
98
- images = [Path(images)]
99
- else:
100
- images = [Path(img) for img in images]
87
+ images = self._normalize_image_input(images)
101
88
 
102
89
  try:
103
90
  all_embeddings = []
@@ -111,7 +98,11 @@ class CohereProvider(EmbeddingProvider):
111
98
  for image in batch_images:
112
99
  if not image.exists():
113
100
  raise EmbeddingError(f"Image not found: {image}")
114
- b64_images.append(image_to_base64(image))
101
+ b64_data, content_type = image_to_base64(image)
102
+ # Construct full data URI for API
103
+ data_uri = f"data:{content_type};base64,{b64_data}"
104
+ b64_images.append(data_uri)
105
+ all_b64_images.append((b64_data, content_type))
115
106
 
116
107
  response = client.embed(
117
108
  model=self.model_name,
@@ -121,24 +112,12 @@ class CohereProvider(EmbeddingProvider):
121
112
  )
122
113
 
123
114
  all_embeddings.extend(np.array(response.embeddings.float_))
124
- all_b64_images.extend(b64_images)
125
-
126
- return EmbeddingResponse(
127
- model_name=self.model_name,
128
- model_provider=self.provider_name,
129
- input_type=input_type,
130
- objects=[
131
- EmbeddingObject(
132
- embedding=all_embeddings[i],
133
- source_b64=all_b64_images[i]
134
- ) for i in range(len(all_embeddings))
135
- ]
115
+
116
+ return self._create_image_response(
117
+ all_embeddings,
118
+ [b64 for b64, _ in all_b64_images],
119
+ [content_type for _, content_type in all_b64_images],
136
120
  )
137
121
 
138
122
  except Exception as e:
139
123
  raise EmbeddingError(f"Failed to embed image with Cohere: {e}") from e
140
-
141
- def embed_pdf(self, pdf_path: Path) -> EmbeddingResponse:
142
- """Generate embeddings for a PDF file using Cohere API."""
143
- image_paths = pdf_to_images(pdf_path)
144
- return self.embed_image(image_paths)
@@ -8,8 +8,9 @@ import numpy as np
8
8
  import torch
9
9
  from PIL import Image
10
10
 
11
- from ..utils import pdf_to_images, image_to_base64
12
- from ..base import EmbeddingProvider, EmbeddingError, EmbeddingResponse, EmbeddingObject
11
+ from ..models import Model
12
+ from ..utils import image_to_base64
13
+ from ..base import EmbeddingProvider, EmbeddingError, EmbeddingResponse
13
14
 
14
15
  logger = logging.getLogger(__name__)
15
16
 
@@ -19,15 +20,17 @@ class ColPaliProvider(EmbeddingProvider):
19
20
 
20
21
  def __init__(
21
22
  self,
22
- model_name: str,
23
+ model: Model.ColPali,
23
24
  text_batch_size: int,
24
25
  image_batch_size: int,
25
26
  device: Optional[str] = None,
26
27
  ):
27
- self.model_name = model_name
28
- self.provider_name = "ColPali"
29
- self.text_batch_size = text_batch_size
30
- self.image_batch_size = image_batch_size
28
+ super().__init__(
29
+ model_name=model.value,
30
+ text_batch_size=text_batch_size,
31
+ image_batch_size=image_batch_size,
32
+ provider_name="ColPali",
33
+ )
31
34
 
32
35
  # Auto-detect device
33
36
  if device is None:
@@ -38,24 +41,43 @@ class ColPaliProvider(EmbeddingProvider):
38
41
  else:
39
42
  device = "cpu"
40
43
 
41
- self.device = device
42
- self._model = None
43
- self._processor = None
44
+ self._hf_device = device
45
+ self._hf_model = None
46
+ self._hf_processor = None
44
47
 
45
48
  def _load_model(self):
46
49
  """Lazy load the model."""
47
- if self._model is None:
50
+ if self._hf_model is None:
48
51
  try:
49
- from colpali_engine.models import ColPali, ColPaliProcessor
50
-
51
- self._model = ColPali.from_pretrained(
52
- self.model_name,
53
- torch_dtype=torch.bfloat16,
54
- device_map=self.device,
55
- ).eval()
56
-
57
- self._processor = ColPaliProcessor.from_pretrained(self.model_name)
58
- logger.info(f"Loaded ColPali model on {self.device}")
52
+ if self.model_name in [Model.ColPali.COLPALI_V1_3.value]:
53
+ from colpali_engine.models import ColPali, ColPaliProcessor
54
+
55
+ self._hf_model = ColPali.from_pretrained(
56
+ self.model_name,
57
+ torch_dtype=torch.bfloat16,
58
+ device_map=self._hf_device,
59
+ ).eval()
60
+
61
+ self._hf_processor = ColPaliProcessor.from_pretrained(self.model_name)
62
+
63
+ elif self.model_name in [
64
+ Model.ColPali.COLSMOL_500M.value,
65
+ Model.ColPali.COLSMOL_256M.value,
66
+ ]:
67
+ from colpali_engine.models import ColIdefics3, ColIdefics3Processor
68
+
69
+ self._hf_model = ColIdefics3.from_pretrained(
70
+ self.model_name,
71
+ torch_dtype=torch.bfloat16,
72
+ device_map=self._hf_device,
73
+ ).eval()
74
+ self._hf_processor = ColIdefics3Processor.from_pretrained(
75
+ self.model_name
76
+ )
77
+ else:
78
+ raise ValueError(f"Unable to load model for: {self.model_name}.")
79
+
80
+ logger.info(f"Loaded {self.model_name} on {self._hf_device}")
59
81
 
60
82
  except ImportError as e:
61
83
  raise EmbeddingError(
@@ -64,38 +86,26 @@ class ColPaliProvider(EmbeddingProvider):
64
86
  except Exception as e:
65
87
  raise EmbeddingError(f"Failed to load model: {e}") from e
66
88
 
67
- def embed_text(self, texts: Union[str, List[str]]) -> EmbeddingResponse:
89
+ def embed_text(self, texts: Union[str, List[str]], **kwargs) -> EmbeddingResponse:
68
90
  """Generate embeddings for text inputs."""
69
91
  self._load_model()
70
-
71
- if isinstance(texts, str):
72
- texts = [texts]
92
+ texts = self._normalize_text_input(texts)
73
93
 
74
94
  try:
75
95
  # Process texts in batches
76
- all_embeddings = []
96
+ all_embeddings: List[np.ndarray] = []
77
97
 
78
98
  for i in range(0, len(texts), self.text_batch_size):
79
99
  batch_texts = texts[i : i + self.text_batch_size]
80
- processed = self._processor.process_queries(batch_texts).to(self.device)
100
+ processed = self._hf_processor.process_queries(batch_texts).to(self._hf_device)
81
101
 
82
102
  with torch.no_grad():
83
- batch_embeddings = self._model(**processed)
103
+ batch_embeddings = self._hf_model(**processed)
84
104
  all_embeddings.append(batch_embeddings.cpu().float().numpy())
85
105
 
86
106
  # Concatenate all batch embeddings
87
107
  final_embeddings = np.concatenate(all_embeddings, axis=0)
88
-
89
- return EmbeddingResponse(
90
- model_name=self.model_name,
91
- model_provider=self.provider_name,
92
- input_type="text",
93
- objects=[
94
- EmbeddingObject(
95
- embedding=e,
96
- ) for e in final_embeddings
97
- ]
98
- )
108
+ return self._create_text_response(final_embeddings)
99
109
 
100
110
  except Exception as e:
101
111
  raise EmbeddingError(f"Failed to embed text: {e}") from e
@@ -105,21 +115,19 @@ class ColPaliProvider(EmbeddingProvider):
105
115
  ) -> EmbeddingResponse:
106
116
  """Generate embeddings for images."""
107
117
  self._load_model()
108
-
109
- if isinstance(images, (str, Path)):
110
- images = [Path(images)]
111
- else:
112
- images = [Path(img) for img in images]
118
+ images = self._normalize_image_input(images)
113
119
 
114
120
  try:
115
121
  # Process images in batches
116
- all_embeddings = []
117
- all_b64_images = []
122
+ all_embeddings: List[np.ndarray] = []
123
+ all_b64_data: List[str] = []
124
+ all_content_types: List[str] = []
118
125
 
119
126
  for i in range(0, len(images), self.image_batch_size):
120
127
  batch_images = images[i : i + self.image_batch_size]
121
128
  pil_images = []
122
- b64_images = []
129
+ batch_b64_data = []
130
+ batch_content_types = []
123
131
 
124
132
  for img_path in batch_images:
125
133
  if not img_path.exists():
@@ -127,34 +135,23 @@ class ColPaliProvider(EmbeddingProvider):
127
135
 
128
136
  with Image.open(img_path) as img:
129
137
  pil_images.append(img.convert("RGB"))
130
- b64_images.append(image_to_base64(img_path))
138
+ b64, content_type = image_to_base64(img_path)
139
+ batch_b64_data.append(b64)
140
+ batch_content_types.append(content_type)
131
141
 
132
- processed = self._processor.process_images(pil_images).to(self.device)
142
+ processed = self._hf_processor.process_images(pil_images).to(self._hf_device)
133
143
 
134
144
  with torch.no_grad():
135
- batch_embeddings = self._model(**processed)
145
+ batch_embeddings = self._hf_model(**processed)
136
146
  all_embeddings.append(batch_embeddings.cpu().float().numpy())
137
- all_b64_images.extend(b64_images)
147
+ all_b64_data.extend(batch_b64_data)
148
+ all_content_types.extend(batch_content_types)
138
149
 
139
150
  # Concatenate all batch embeddings
140
151
  final_embeddings = np.concatenate(all_embeddings, axis=0)
141
-
142
- return EmbeddingResponse(
143
- model_name=self.model_name,
144
- model_provider=self.provider_name,
145
- input_type="image",
146
- objects=[
147
- EmbeddingObject(
148
- embedding=final_embeddings[i],
149
- source_b64=all_b64_images[i]
150
- ) for i in range(len(final_embeddings))
151
- ]
152
+ return self._create_image_response(
153
+ final_embeddings, all_b64_data, all_content_types
152
154
  )
153
155
 
154
156
  except Exception as e:
155
157
  raise EmbeddingError(f"Failed to embed images: {e}") from e
156
-
157
- def embed_pdf(self, pdf_path: Path) -> EmbeddingResponse:
158
- """Generate embeddings for a PDF file using ColPali API."""
159
- images = pdf_to_images(pdf_path)
160
- return self.embed_image(images)
embedkit/utils.py CHANGED
@@ -1,37 +1,84 @@
1
+ import tempfile
2
+ import shutil
3
+ import logging
4
+ from contextlib import contextmanager
1
5
  from pdf2image import convert_from_path
2
6
  from pathlib import Path
3
7
  from .config import get_temp_dir
4
- from typing import Union
8
+ from typing import Union, List, Iterator, Callable, TypeVar, Any
5
9
 
10
+ logger = logging.getLogger(__name__)
6
11
 
7
- def pdf_to_images(pdf_path: Path) -> list[Path]:
8
- """Convert a PDF file to a list of images."""
9
- root_temp_dir = get_temp_dir()
10
- img_temp_dir = root_temp_dir / "images"
11
- img_temp_dir.mkdir(parents=True, exist_ok=True)
12
- images = convert_from_path(pdf_path=str(pdf_path), output_folder=str(img_temp_dir))
13
- image_paths = []
12
+ T = TypeVar("T")
14
13
 
15
- for i, image in enumerate(images):
16
- output_path = img_temp_dir / f"{pdf_path.stem}_{i}.png"
17
- if output_path.exists():
18
- output_path.unlink()
19
14
 
20
- image.save(output_path)
21
- image_paths.append(output_path)
22
- return image_paths
15
+ @contextmanager
16
+ def temporary_directory() -> Iterator[Path]:
17
+ """Create a temporary directory that is automatically cleaned up when done.
23
18
 
19
+ Yields:
20
+ Path: Path to the temporary directory
21
+ """
22
+ temp_dir = Path(tempfile.mkdtemp())
23
+ try:
24
+ yield temp_dir
25
+ finally:
26
+ shutil.rmtree(temp_dir)
27
+
28
+
29
+ def pdf_to_images(pdf_path: Path) -> List[Path]:
30
+ """Convert a PDF file to a list of images.
31
+
32
+ The images are stored in a temporary directory that will be automatically
33
+ cleaned up when the process exits.
34
+
35
+ Args:
36
+ pdf_path: Path to the PDF file
37
+
38
+ Returns:
39
+ List[Path]: List of paths to the generated images
40
+
41
+ Note:
42
+ The temporary files will be automatically cleaned up when the process exits.
43
+ Do not rely on these files persisting after the function returns.
44
+ """
45
+ with temporary_directory() as temp_dir:
46
+ images = convert_from_path(pdf_path=str(pdf_path), output_folder=str(temp_dir))
47
+ image_paths = []
48
+
49
+ for i, image in enumerate(images):
50
+ output_path = temp_dir / f"{pdf_path.stem}_{i}.png"
51
+ image.save(output_path)
52
+ final_path = Path(tempfile.mktemp(suffix=".png"))
53
+ shutil.move(output_path, final_path)
54
+ image_paths.append(final_path)
55
+
56
+ return image_paths
57
+
58
+
59
+ def image_to_base64(image_path: Union[str, Path]) -> tuple[str, str]:
60
+ """Convert an image to base64 and return the base64 data and content type.
24
61
 
25
- def image_to_base64(image_path: Union[str, Path]):
62
+ Args:
63
+ image_path: Path to the image file
64
+
65
+ Returns:
66
+ tuple[str, str]: (base64_data, content_type)
67
+
68
+ Raises:
69
+ ValueError: If the image cannot be read or has an unsupported format
70
+ """
26
71
  import base64
27
72
 
28
73
  try:
29
- base64_only = base64.b64encode(Path(image_path).read_bytes()).decode("utf-8")
74
+ base64_data = base64.b64encode(Path(image_path).read_bytes()).decode("utf-8")
30
75
  except Exception as e:
31
76
  raise ValueError(f"Failed to read image {image_path}: {e}") from e
32
77
 
33
78
  if isinstance(image_path, Path):
34
79
  image_path_str = str(image_path)
80
+ else:
81
+ image_path_str = image_path
35
82
 
36
83
  if image_path_str.lower().endswith(".png"):
37
84
  content_type = "image/png"
@@ -43,6 +90,52 @@ def image_to_base64(image_path: Union[str, Path]):
43
90
  raise ValueError(
44
91
  f"Unsupported image format for {image_path}; expected .png, .jpg, .jpeg, or .gif"
45
92
  )
46
- base64_image = f"data:{content_type};base64,{base64_only}"
47
93
 
48
- return base64_image
94
+ return base64_data, content_type
95
+
96
+
97
+ def with_pdf_cleanup(embed_func: Callable[..., T]) -> Callable[..., T]:
98
+ """Decorator to handle PDF to image conversion with automatic cleanup.
99
+
100
+ This decorator handles the common pattern of:
101
+ 1. Converting PDF to images
102
+ 2. Passing images to an embedding function
103
+ 3. Cleaning up temporary files
104
+
105
+ Args:
106
+ embed_func: Function that takes a list of image paths and returns embeddings
107
+
108
+ Returns:
109
+ Callable that takes a PDF path and returns embeddings
110
+ """
111
+
112
+ def wrapper(*args, **kwargs) -> T:
113
+ # First argument is self for instance methods
114
+ pdf_path = args[-1] if args else kwargs.get("pdf_path")
115
+ if not pdf_path:
116
+ raise ValueError(
117
+ "PDF path must be provided as the last positional argument or as 'pdf_path' keyword argument"
118
+ )
119
+
120
+ try:
121
+ images = pdf_to_images(pdf_path)
122
+ # Call the original function with the images instead of pdf_path
123
+ if args:
124
+ # For instance methods, replace the last argument (pdf_path) with images
125
+ args = list(args)
126
+ args[-1] = images
127
+ else:
128
+ kwargs["pdf_path"] = images
129
+ return embed_func(*args, **kwargs)
130
+ finally:
131
+ # Clean up temporary files created by pdf_to_images
132
+ for img_path in images:
133
+ try:
134
+ if img_path.exists() and str(img_path).startswith(
135
+ tempfile.gettempdir()
136
+ ):
137
+ img_path.unlink()
138
+ except Exception as e:
139
+ logger.warning(f"Failed to clean up temporary file {img_path}: {e}")
140
+
141
+ return wrapper
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: embedkit
3
- Version: 0.1.4
3
+ Version: 0.1.5
4
4
  Summary: A simple toolkit for generating vector embeddings across multiple providers and models
5
5
  Author-email: JP Hwang <me@jphwang.com>
6
6
  License: MIT
@@ -22,7 +22,7 @@ Requires-Dist: colpali-engine<0.4.0,>=0.3.0
22
22
  Requires-Dist: pdf2image>=1.17.0
23
23
  Requires-Dist: pillow>=11.2.1
24
24
  Requires-Dist: torch<=2.5
25
- Requires-Dist: transformers
25
+ Requires-Dist: transformers>=4.46.2
26
26
  Description-Content-Type: text/markdown
27
27
 
28
28
  # EmbedKit
@@ -45,7 +45,7 @@ from embedkit.classes import Model, CohereInputType
45
45
 
46
46
  # Initialize with ColPali
47
47
  kit = EmbedKit.colpali(
48
- model=Model.ColPali.V1_3,
48
+ model=Model.ColPali.COLPALI_V1_3, # or COLSMOL_256M, COLSMOL_500M
49
49
  text_batch_size=16, # Optional: process text in batches of 16
50
50
  image_batch_size=8, # Optional: process images in batches of 8
51
51
  )
@@ -54,7 +54,7 @@ kit = EmbedKit.colpali(
54
54
  result = kit.embed_text("Hello world")
55
55
  print(result.model_provider)
56
56
  print(result.input_type)
57
- print(result.objects[0].embedding.shape)
57
+ print(result.objects[0].embedding.shape) # Returns 2D array for ColPali
58
58
  print(result.objects[0].source_b64)
59
59
 
60
60
  # Initialize with Cohere
@@ -70,7 +70,7 @@ kit = EmbedKit.cohere(
70
70
  result = kit.embed_text("Hello world")
71
71
  print(result.model_provider)
72
72
  print(result.input_type)
73
- print(result.objects[0].embedding.shape)
73
+ print(result.objects[0].embedding.shape) # Returns 1D array for Cohere
74
74
  print(result.objects[0].source_b64)
75
75
  ```
76
76
 
@@ -85,8 +85,8 @@ result = kit.embed_image(image_path)
85
85
 
86
86
  print(result.model_provider)
87
87
  print(result.input_type)
88
- print(result.objects[0].embedding.shape)
89
- print(result.objects[0].source_b64)
88
+ print(result.objects[0].embedding.shape) # 2D for ColPali, 1D for Cohere
89
+ print(result.objects[0].source_b64) # Base64 encoded image
90
90
  ```
91
91
 
92
92
  ### PDF Embeddings
@@ -100,8 +100,8 @@ result = kit.embed_pdf(pdf_path)
100
100
 
101
101
  print(result.model_provider)
102
102
  print(result.input_type)
103
- print(result.objects[0].embedding.shape)
104
- print(result.objects[0].source_b64)
103
+ print(result.objects[0].embedding.shape) # 2D for ColPali, 1D for Cohere
104
+ print(result.objects[0].source_b64) # Base64 encoded PDF page
105
105
  ```
106
106
 
107
107
  ## Response Format
@@ -116,17 +116,23 @@ class EmbeddingResponse:
116
116
  objects: List[EmbeddingObject]
117
117
 
118
118
  class EmbeddingObject:
119
- embedding: np.ndarray
120
- source_b64: Optional[str]
119
+ embedding: np.ndarray # 1D array for Cohere, 2D array for ColPali
120
+ source_b64: Optional[str] # Base64 encoded source for images and PDFs
121
121
  ```
122
122
 
123
123
  ## Supported Models
124
124
 
125
125
  ### ColPali
126
- - `Model.ColPali.V1_3`
126
+ - `Model.ColPali.COLPALI_V1_3`
127
+ - `Model.ColPali.COLSMOL_256M`
128
+ - `Model.ColPali.COLSMOL_500M`
127
129
 
128
130
  ### Cohere
129
131
  - `Model.Cohere.EMBED_V4_0`
132
+ - `Model.Cohere.EMBED_ENGLISH_V3_0`
133
+ - `Model.Cohere.EMBED_ENGLISH_LIGHT_V3_0`
134
+ - `Model.Cohere.EMBED_MULTILINGUAL_V3_0`
135
+ - `Model.Cohere.EMBED_MULTILINGUAL_LIGHT_V3_0`
130
136
 
131
137
  ## Requirements
132
138
 
@@ -0,0 +1,13 @@
1
+ embedkit/__init__.py,sha256=ahMyC4SjYLr3QpM39obwDKpJO_HQ4_kqvG4d7jD5XtA,4503
2
+ embedkit/base.py,sha256=FfyzCqG7azs4yJj4RU6QKQRy9_wfd3KUFidThplmEo0,3739
3
+ embedkit/classes.py,sha256=LI5egTsfcBglUqIxBrDV_ymA1eKcxOUHJeMAhKVtU_Y,582
4
+ embedkit/config.py,sha256=EVGODSKxQAr46bU8dyORFunsfRuj6dnvtSqa4MxUZCo,138
5
+ embedkit/models.py,sha256=fndxt6QoV7nuyeXijs39K3QLs1l8ATw2PX2X-64oKt8,575
6
+ embedkit/utils.py,sha256=GV2dqNdI8XKP3iRD6zyoY5jlAcqVt1W5pxh3Rmn81Cc,4505
7
+ embedkit/providers/__init__.py,sha256=HaS-HNQabvhn9xLNZCq3VUqPCb7rGG4pvgvpKP4AXcw,201
8
+ embedkit/providers/cohere.py,sha256=sD9wsh8ifUYnuBbOeMklvBo49rlnsf-WU568W_fu8dM,4275
9
+ embedkit/providers/colpali.py,sha256=swdn8xHyn-Xob3POQFEbhwcjR_oLp7uVUMXc2gOggK8,5845
10
+ embedkit-0.1.5.dist-info/METADATA,sha256=hlUZDvuL7FXa1XjCZdZSftiT9f4rGc7DUoYkodAG94g,3871
11
+ embedkit-0.1.5.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
12
+ embedkit-0.1.5.dist-info/licenses/LICENSE,sha256=-g2Rad7b3rb2oVwOTwfMOIpscHT1zuaJoguamLRCBJs,1072
13
+ embedkit-0.1.5.dist-info/RECORD,,
@@ -1,13 +0,0 @@
1
- embedkit/__init__.py,sha256=O7C_e30VMrLliDamyNM_srdyuGz50PvGzBJo9NgePoU,4555
2
- embedkit/base.py,sha256=QkP-Ih7aYjIaxP5Hgp7MVypQlI5mhyOqLWKe0lOpIrQ,1344
3
- embedkit/classes.py,sha256=TlCl58Vo8p0q8SfJGINtBmCXY2yVFLtdlfQqBKrxivw,600
4
- embedkit/config.py,sha256=EVGODSKxQAr46bU8dyORFunsfRuj6dnvtSqa4MxUZCo,138
5
- embedkit/models.py,sha256=EBIYkyZeIhGaOPL-9bslHHdLaZ7qzOYLd0qxVZ7VX7w,226
6
- embedkit/utils.py,sha256=91BzzvbYSrUsWeW3CTAw3yK-M3S5FgQXov16gxffkUo,1572
7
- embedkit/providers/__init__.py,sha256=HaS-HNQabvhn9xLNZCq3VUqPCb7rGG4pvgvpKP4AXcw,201
8
- embedkit/providers/cohere.py,sha256=5-ux5UzpRqlRPm2PjgmRfcxH4NhRRjXa3cOZGWp0jDc,4880
9
- embedkit/providers/colpali.py,sha256=PUjTESyF0r__JdgfT22kK1DGrf7XwYC2EuyBelspTsc,5500
10
- embedkit-0.1.4.dist-info/METADATA,sha256=KPCFTWyeh_8unP1zBmVCUmA0c-yfXwxeMdWf9zO4CjA,3316
11
- embedkit-0.1.4.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
12
- embedkit-0.1.4.dist-info/licenses/LICENSE,sha256=-g2Rad7b3rb2oVwOTwfMOIpscHT1zuaJoguamLRCBJs,1072
13
- embedkit-0.1.4.dist-info/RECORD,,