embedkit 0.1.4__tar.gz → 0.1.6__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (27)
  1. {embedkit-0.1.4 → embedkit-0.1.6}/PKG-INFO +18 -12
  2. {embedkit-0.1.4 → embedkit-0.1.6}/README.md +16 -10
  3. {embedkit-0.1.4 → embedkit-0.1.6}/main.py +36 -28
  4. {embedkit-0.1.4 → embedkit-0.1.6}/pyproject.toml +2 -2
  5. {embedkit-0.1.4 → embedkit-0.1.6}/src/embedkit/__init__.py +10 -12
  6. embedkit-0.1.6/src/embedkit/base.py +122 -0
  7. {embedkit-0.1.4 → embedkit-0.1.6}/src/embedkit/classes.py +1 -6
  8. embedkit-0.1.6/src/embedkit/models.py +18 -0
  9. {embedkit-0.1.4 → embedkit-0.1.6}/src/embedkit/providers/cohere.py +31 -44
  10. embedkit-0.1.6/src/embedkit/providers/colpali.py +162 -0
  11. embedkit-0.1.6/src/embedkit/utils.py +142 -0
  12. embedkit-0.1.6/tests/fixtures/2407.01449v6_p1_p5.pdf +0 -0
  13. {embedkit-0.1.4 → embedkit-0.1.6}/tests/test_embedkit.py +2 -2
  14. embedkit-0.1.6/tests/test_utils.py +52 -0
  15. {embedkit-0.1.4 → embedkit-0.1.6}/uv.lock +2 -2
  16. embedkit-0.1.4/src/embedkit/base.py +0 -53
  17. embedkit-0.1.4/src/embedkit/models.py +0 -12
  18. embedkit-0.1.4/src/embedkit/providers/colpali.py +0 -160
  19. embedkit-0.1.4/src/embedkit/utils.py +0 -48
  20. {embedkit-0.1.4 → embedkit-0.1.6}/.gitignore +0 -0
  21. {embedkit-0.1.4 → embedkit-0.1.6}/.python-version +0 -0
  22. {embedkit-0.1.4 → embedkit-0.1.6}/LICENSE +0 -0
  23. {embedkit-0.1.4 → embedkit-0.1.6}/src/embedkit/config.py +0 -0
  24. {embedkit-0.1.4 → embedkit-0.1.6}/src/embedkit/providers/__init__.py +0 -0
  25. {embedkit-0.1.4 → embedkit-0.1.6}/tests/conftest.py +0 -0
  26. {embedkit-0.1.4 → embedkit-0.1.6}/tests/fixtures/2407.01449v6_p1.pdf +0 -0
  27. {embedkit-0.1.4 → embedkit-0.1.6}/tests/fixtures/2407.01449v6_p1.png +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: embedkit
3
- Version: 0.1.4
3
+ Version: 0.1.6
4
4
  Summary: A simple toolkit for generating vector embeddings across multiple providers and models
5
5
  Author-email: JP Hwang <me@jphwang.com>
6
6
  License: MIT
@@ -22,7 +22,7 @@ Requires-Dist: colpali-engine<0.4.0,>=0.3.0
22
22
  Requires-Dist: pdf2image>=1.17.0
23
23
  Requires-Dist: pillow>=11.2.1
24
24
  Requires-Dist: torch<=2.5
25
- Requires-Dist: transformers
25
+ Requires-Dist: transformers>=4.46.2
26
26
  Description-Content-Type: text/markdown
27
27
 
28
28
  # EmbedKit
@@ -45,7 +45,7 @@ from embedkit.classes import Model, CohereInputType
45
45
 
46
46
  # Initialize with ColPali
47
47
  kit = EmbedKit.colpali(
48
- model=Model.ColPali.V1_3,
48
+ model=Model.ColPali.COLPALI_V1_3, # or COLSMOL_256M, COLSMOL_500M
49
49
  text_batch_size=16, # Optional: process text in batches of 16
50
50
  image_batch_size=8, # Optional: process images in batches of 8
51
51
  )
@@ -54,7 +54,7 @@ kit = EmbedKit.colpali(
54
54
  result = kit.embed_text("Hello world")
55
55
  print(result.model_provider)
56
56
  print(result.input_type)
57
- print(result.objects[0].embedding.shape)
57
+ print(result.objects[0].embedding.shape) # Returns 2D array for ColPali
58
58
  print(result.objects[0].source_b64)
59
59
 
60
60
  # Initialize with Cohere
@@ -70,7 +70,7 @@ kit = EmbedKit.cohere(
70
70
  result = kit.embed_text("Hello world")
71
71
  print(result.model_provider)
72
72
  print(result.input_type)
73
- print(result.objects[0].embedding.shape)
73
+ print(result.objects[0].embedding.shape) # Returns 1D array for Cohere
74
74
  print(result.objects[0].source_b64)
75
75
  ```
76
76
 
@@ -85,8 +85,8 @@ result = kit.embed_image(image_path)
85
85
 
86
86
  print(result.model_provider)
87
87
  print(result.input_type)
88
- print(result.objects[0].embedding.shape)
89
- print(result.objects[0].source_b64)
88
+ print(result.objects[0].embedding.shape) # 2D for ColPali, 1D for Cohere
89
+ print(result.objects[0].source_b64) # Base64 encoded image
90
90
  ```
91
91
 
92
92
  ### PDF Embeddings
@@ -100,8 +100,8 @@ result = kit.embed_pdf(pdf_path)
100
100
 
101
101
  print(result.model_provider)
102
102
  print(result.input_type)
103
- print(result.objects[0].embedding.shape)
104
- print(result.objects[0].source_b64)
103
+ print(result.objects[0].embedding.shape) # 2D for ColPali, 1D for Cohere
104
+ print(result.objects[0].source_b64) # Base64 encoded PDF page
105
105
  ```
106
106
 
107
107
  ## Response Format
@@ -116,17 +116,23 @@ class EmbeddingResponse:
116
116
  objects: List[EmbeddingObject]
117
117
 
118
118
  class EmbeddingObject:
119
- embedding: np.ndarray
120
- source_b64: Optional[str]
119
+ embedding: np.ndarray # 1D array for Cohere, 2D array for ColPali
120
+ source_b64: Optional[str] # Base64 encoded source for images and PDFs
121
121
  ```
122
122
 
123
123
  ## Supported Models
124
124
 
125
125
  ### ColPali
126
- - `Model.ColPali.V1_3`
126
+ - `Model.ColPali.COLPALI_V1_3`
127
+ - `Model.ColPali.COLSMOL_256M`
128
+ - `Model.ColPali.COLSMOL_500M`
127
129
 
128
130
  ### Cohere
129
131
  - `Model.Cohere.EMBED_V4_0`
132
+ - `Model.Cohere.EMBED_ENGLISH_V3_0`
133
+ - `Model.Cohere.EMBED_ENGLISH_LIGHT_V3_0`
134
+ - `Model.Cohere.EMBED_MULTILINGUAL_V3_0`
135
+ - `Model.Cohere.EMBED_MULTILINGUAL_LIGHT_V3_0`
130
136
 
131
137
  ## Requirements
132
138
 
@@ -18,7 +18,7 @@ from embedkit.classes import Model, CohereInputType
18
18
 
19
19
  # Initialize with ColPali
20
20
  kit = EmbedKit.colpali(
21
- model=Model.ColPali.V1_3,
21
+ model=Model.ColPali.COLPALI_V1_3, # or COLSMOL_256M, COLSMOL_500M
22
22
  text_batch_size=16, # Optional: process text in batches of 16
23
23
  image_batch_size=8, # Optional: process images in batches of 8
24
24
  )
@@ -27,7 +27,7 @@ kit = EmbedKit.colpali(
27
27
  result = kit.embed_text("Hello world")
28
28
  print(result.model_provider)
29
29
  print(result.input_type)
30
- print(result.objects[0].embedding.shape)
30
+ print(result.objects[0].embedding.shape) # Returns 2D array for ColPali
31
31
  print(result.objects[0].source_b64)
32
32
 
33
33
  # Initialize with Cohere
@@ -43,7 +43,7 @@ kit = EmbedKit.cohere(
43
43
  result = kit.embed_text("Hello world")
44
44
  print(result.model_provider)
45
45
  print(result.input_type)
46
- print(result.objects[0].embedding.shape)
46
+ print(result.objects[0].embedding.shape) # Returns 1D array for Cohere
47
47
  print(result.objects[0].source_b64)
48
48
  ```
49
49
 
@@ -58,8 +58,8 @@ result = kit.embed_image(image_path)
58
58
 
59
59
  print(result.model_provider)
60
60
  print(result.input_type)
61
- print(result.objects[0].embedding.shape)
62
- print(result.objects[0].source_b64)
61
+ print(result.objects[0].embedding.shape) # 2D for ColPali, 1D for Cohere
62
+ print(result.objects[0].source_b64) # Base64 encoded image
63
63
  ```
64
64
 
65
65
  ### PDF Embeddings
@@ -73,8 +73,8 @@ result = kit.embed_pdf(pdf_path)
73
73
 
74
74
  print(result.model_provider)
75
75
  print(result.input_type)
76
- print(result.objects[0].embedding.shape)
77
- print(result.objects[0].source_b64)
76
+ print(result.objects[0].embedding.shape) # 2D for ColPali, 1D for Cohere
77
+ print(result.objects[0].source_b64) # Base64 encoded PDF page
78
78
  ```
79
79
 
80
80
  ## Response Format
@@ -89,17 +89,23 @@ class EmbeddingResponse:
89
89
  objects: List[EmbeddingObject]
90
90
 
91
91
  class EmbeddingObject:
92
- embedding: np.ndarray
93
- source_b64: Optional[str]
92
+ embedding: np.ndarray # 1D array for Cohere, 2D array for ColPali
93
+ source_b64: Optional[str] # Base64 encoded source for images and PDFs
94
94
  ```
95
95
 
96
96
  ## Supported Models
97
97
 
98
98
  ### ColPali
99
- - `Model.ColPali.V1_3`
99
+ - `Model.ColPali.COLPALI_V1_3`
100
+ - `Model.ColPali.COLSMOL_256M`
101
+ - `Model.ColPali.COLSMOL_500M`
100
102
 
101
103
  ### Cohere
102
104
  - `Model.Cohere.EMBED_V4_0`
105
+ - `Model.Cohere.EMBED_ENGLISH_V3_0`
106
+ - `Model.Cohere.EMBED_ENGLISH_LIGHT_V3_0`
107
+ - `Model.Cohere.EMBED_MULTILINGUAL_V3_0`
108
+ - `Model.Cohere.EMBED_MULTILINGUAL_LIGHT_V3_0`
103
109
 
104
110
  ## Requirements
105
111
 
@@ -32,30 +32,7 @@ def get_sample_image() -> Path:
32
32
  sample_image = get_sample_image()
33
33
 
34
34
  sample_pdf = Path("tests/fixtures/2407.01449v6_p1.pdf")
35
- long_pdf = Path("tmp/2407.01449v6.pdf")
36
-
37
- kit = EmbedKit.colpali(model=Model.ColPali.V1_3, text_batch_size=16, image_batch_size=8)
38
-
39
- results = kit.embed_text("Hello world")
40
- assert len(results.objects) == 1
41
- assert len(results.objects[0].embedding.shape) == 2
42
- assert results.objects[0].source_b64 == None
43
-
44
- results = kit.embed_image(sample_image)
45
- assert len(results.objects) == 1
46
- assert len(results.objects[0].embedding.shape) == 2
47
- assert type(results.objects[0].source_b64) == str
48
-
49
- results = kit.embed_pdf(sample_pdf)
50
- assert len(results.objects) == 1
51
- assert len(results.objects[0].embedding.shape) == 2
52
- assert type(results.objects[0].source_b64) == str
53
-
54
- # results = kit.embed_pdf(long_pdf)
55
- # assert len(results.objects) == 26
56
- # assert len(results.objects[0].embedding.shape) == 2
57
- # assert type(results.objects[0].source_b64) == str
58
-
35
+ longer_pdf = Path("tests/fixtures/2407.01449v6_p1_p5.pdf")
59
36
 
60
37
  kit = EmbedKit.cohere(
61
38
  model=Model.Cohere.EMBED_V4_0,
@@ -65,6 +42,7 @@ kit = EmbedKit.cohere(
65
42
  text_input_type=CohereInputType.SEARCH_QUERY,
66
43
  )
67
44
 
45
+ print(f"Trying out Cohere")
68
46
  results = kit.embed_text("Hello world")
69
47
  assert len(results.objects) == 1
70
48
  assert len(results.objects[0].embedding.shape) == 1
@@ -93,7 +71,37 @@ assert len(results.objects) == 1
93
71
  assert len(results.objects[0].embedding.shape) == 1
94
72
  assert type(results.objects[0].source_b64) == str
95
73
 
96
- # results = kit.embed_pdf(long_pdf)
97
- # assert len(results.objects) == 1
98
- # assert len(results.objects[0].embedding.shape) == 1
99
- # assert type(results.objects[0].source_b64) == str
74
+ results = kit.embed_pdf(longer_pdf)
75
+ assert len(results.objects) == 5
76
+ assert len(results.objects[0].embedding.shape) == 1
77
+ assert type(results.objects[0].source_b64) == str
78
+
79
+ for colpali_model in [
80
+ Model.ColPali.COLSMOL_256M,
81
+ Model.ColPali.COLSMOL_500M,
82
+ Model.ColPali.COLPALI_V1_3,
83
+ ]:
84
+ print(f"Trying out {colpali_model}")
85
+ kit = EmbedKit.colpali(
86
+ model=colpali_model, text_batch_size=16, image_batch_size=8
87
+ )
88
+
89
+ results = kit.embed_text("Hello world")
90
+ assert len(results.objects) == 1
91
+ assert len(results.objects[0].embedding.shape) == 2
92
+ assert results.objects[0].source_b64 == None
93
+
94
+ results = kit.embed_image(sample_image)
95
+ assert len(results.objects) == 1
96
+ assert len(results.objects[0].embedding.shape) == 2
97
+ assert type(results.objects[0].source_b64) == str
98
+
99
+ results = kit.embed_pdf(sample_pdf)
100
+ assert len(results.objects) == 1
101
+ assert len(results.objects[0].embedding.shape) == 2
102
+ assert type(results.objects[0].source_b64) == str
103
+
104
+ results = kit.embed_pdf(longer_pdf)
105
+ assert len(results.objects) == 5
106
+ assert len(results.objects[0].embedding.shape) == 2
107
+ assert type(results.objects[0].source_b64) == str
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "embedkit"
3
- version = "0.1.4"
3
+ version = "0.1.6"
4
4
  description = "A simple toolkit for generating vector embeddings across multiple providers and models"
5
5
  readme = "README.md"
6
6
  requires-python = ">=3.10"
@@ -11,7 +11,7 @@ dependencies = [
11
11
  "pdf2image>=1.17.0",
12
12
  "pillow>=11.2.1",
13
13
  "torch<=2.5",
14
- "transformers",
14
+ "transformers>=4.46.2",
15
15
  ]
16
16
  authors = [
17
17
  {name = "JP Hwang", email = "me@jphwang.com"},
@@ -5,7 +5,6 @@ EmbedKit: A unified toolkit for generating vector embeddings.
5
5
 
6
6
  from typing import Union, List, Optional
7
7
  from pathlib import Path
8
- import numpy as np
9
8
 
10
9
  from .models import Model
11
10
  from .base import EmbeddingError, EmbeddingResponse
@@ -28,7 +27,7 @@ class EmbedKit:
28
27
  @classmethod
29
28
  def colpali(
30
29
  cls,
31
- model: Model = Model.ColPali.V1_3,
30
+ model: Model = Model.ColPali.COLPALI_V1_3,
32
31
  device: Optional[str] = None,
33
32
  text_batch_size: int = 32,
34
33
  image_batch_size: int = 8,
@@ -42,13 +41,13 @@ class EmbedKit:
42
41
  text_batch_size: Batch size for text embedding generation
43
42
  image_batch_size: Batch size for image embedding generation
44
43
  """
45
- if model == Model.ColPali.V1_3:
46
- model_name = "vidore/colpali-v1.3"
47
- else:
48
- raise ValueError(f"Unsupported model: {model}")
44
+ if not isinstance(model, Model.ColPali):
45
+ raise ValueError(
46
+ f"Unsupported model: {model}. Must be a Model.ColPali enum value."
47
+ )
49
48
 
50
49
  provider = ColPaliProvider(
51
- model_name=model_name,
50
+ model=model,
52
51
  device=device,
53
52
  text_batch_size=text_batch_size,
54
53
  image_batch_size=image_batch_size,
@@ -77,16 +76,15 @@ class EmbedKit:
77
76
  if not api_key:
78
77
  raise ValueError("API key is required")
79
78
 
80
- if model == Model.Cohere.EMBED_V4_0:
81
- model_name = "embed-v4.0"
82
- else:
79
+ if not isinstance(model, Model.Cohere):
83
80
  raise ValueError(f"Unsupported model: {model}")
84
81
 
85
82
  provider = CohereProvider(
86
- api_key=api_key, model_name=model_name,
83
+ api_key=api_key,
84
+ model=model,
87
85
  text_batch_size=text_batch_size,
88
86
  image_batch_size=image_batch_size,
89
- text_input_type=text_input_type
87
+ text_input_type=text_input_type,
90
88
  )
91
89
  return cls(provider)
92
90
 
@@ -0,0 +1,122 @@
1
+ # ./src/embedkit/base.py
2
+ """Base classes for EmbedKit."""
3
+
4
+ from abc import ABC, abstractmethod
5
+ from typing import Union, List, Optional
6
+ from pathlib import Path
7
+ import numpy as np
8
+ from dataclasses import dataclass
9
+
10
+ from .models import Model
11
+ from .utils import with_pdf_cleanup
12
+
13
+
14
+ @dataclass
15
+ class EmbeddingObject:
16
+ embedding: np.ndarray
17
+ source_b64: str = None
18
+ source_content_type: str = None # e.g., "image/png", "image/jpeg"
19
+
20
+
21
+ @dataclass
22
+ class EmbeddingResponse:
23
+ model_name: str
24
+ model_provider: str
25
+ input_type: str
26
+ objects: List[EmbeddingObject]
27
+
28
+ @property
29
+ def shape(self) -> tuple:
30
+ return self.objects[0].embedding.shape
31
+
32
+
33
+ class EmbeddingProvider(ABC):
34
+ """Abstract base class for embedding providers."""
35
+
36
+ def __init__(
37
+ self,
38
+ model_name: str,
39
+ text_batch_size: int,
40
+ image_batch_size: int,
41
+ provider_name: str,
42
+ ):
43
+ self.model_name = model_name
44
+ self.provider_name = provider_name
45
+ self.text_batch_size = text_batch_size
46
+ self.image_batch_size = image_batch_size
47
+
48
+ def _normalize_text_input(self, texts: Union[str, List[str]]) -> List[str]:
49
+ """Normalize text input to a list of strings."""
50
+ if isinstance(texts, str):
51
+ return [texts]
52
+ return texts
53
+
54
+ def _normalize_image_input(
55
+ self, images: Union[Path, str, List[Union[Path, str]]]
56
+ ) -> List[Path]:
57
+ """Normalize image input to a list of Path objects."""
58
+ if isinstance(images, (str, Path)):
59
+ return [Path(images)]
60
+ return [Path(img) for img in images]
61
+
62
+ def _create_text_response(
63
+ self, embeddings: List[np.ndarray], input_type: str = "text"
64
+ ) -> EmbeddingResponse:
65
+ """Create a standardized text embedding response."""
66
+ return EmbeddingResponse(
67
+ model_name=self.model_name,
68
+ model_provider=self.provider_name,
69
+ input_type=input_type,
70
+ objects=[EmbeddingObject(embedding=e) for e in embeddings],
71
+ )
72
+
73
+ def _create_image_response(
74
+ self,
75
+ embeddings: List[np.ndarray],
76
+ b64_data: List[str],
77
+ content_types: List[str],
78
+ input_type: str = "image",
79
+ ) -> EmbeddingResponse:
80
+ """Create a standardized image embedding response."""
81
+ return EmbeddingResponse(
82
+ model_name=self.model_name,
83
+ model_provider=self.provider_name,
84
+ input_type=input_type,
85
+ objects=[
86
+ EmbeddingObject(
87
+ embedding=embedding,
88
+ source_b64=b64_data,
89
+ source_content_type=content_type,
90
+ )
91
+ for embedding, b64_data, content_type in zip(
92
+ embeddings, b64_data, content_types
93
+ )
94
+ ],
95
+ )
96
+
97
+ @abstractmethod
98
+ def embed_text(self, texts: Union[str, List[str]], **kwargs) -> EmbeddingResponse:
99
+ """Generate document text embeddings using the configured provider."""
100
+ pass
101
+
102
+ @abstractmethod
103
+ def embed_image(
104
+ self, images: Union[Path, str, List[Union[Path, str]]]
105
+ ) -> EmbeddingResponse:
106
+ """Generate image embeddings using the configured provider."""
107
+ pass
108
+
109
+ def embed_pdf(self, pdf_path: Path) -> EmbeddingResponse:
110
+ """Generate embeddings for a PDF file."""
111
+ return self._embed_pdf_impl(pdf_path)
112
+
113
+ @with_pdf_cleanup
114
+ def _embed_pdf_impl(self, pdf_path: List[Path]) -> EmbeddingResponse:
115
+ """Internal implementation of PDF embedding with cleanup handled by decorator."""
116
+ return self.embed_image(pdf_path)
117
+
118
+
119
+ class EmbeddingError(Exception):
120
+ """Base exception for embedding-related errors."""
121
+
122
+ pass
@@ -13,9 +13,4 @@ from . import EmbeddingResponse, EmbeddingError
13
13
  from .models import Model
14
14
  from .providers.cohere import CohereInputType
15
15
 
16
- __all__ = [
17
- "EmbeddingResponse",
18
- "EmbeddingError",
19
- "Model",
20
- "CohereInputType"
21
- ]
16
+ __all__ = ["EmbeddingResponse", "EmbeddingError", "Model", "CohereInputType"]
@@ -0,0 +1,18 @@
1
+ # ./src/embedkit/models.py
2
+ """Model definitions and enum for EmbedKit."""
3
+
4
+ from enum import Enum
5
+
6
+
7
+ class Model:
8
+ class ColPali(Enum):
9
+ COLPALI_V1_3 = "vidore/colpali-v1.3"
10
+ COLSMOL_500M = "vidore/colSmol-500M"
11
+ COLSMOL_256M = "vidore/colSmol-256M"
12
+
13
+ class Cohere(Enum):
14
+ EMBED_V4_0 = "embed-v4.0"
15
+ EMBED_ENGLISH_V3_0 = "embed-english-v3.0"
16
+ EMBED_ENGLISH_LIGHT_V3_0 = "embed-english-light-v3.0"
17
+ EMBED_MULTILINGUAL_V3_0 = "embed-multilingual-v3.0"
18
+ EMBED_MULTILINGUAL_LIGHT_V3_0 = "embed-multilingual-light-v3.0"
@@ -5,9 +5,13 @@ from typing import Union, List
5
5
  from pathlib import Path
6
6
  import numpy as np
7
7
  from enum import Enum
8
+ import logging
8
9
 
9
- from ..utils import pdf_to_images, image_to_base64
10
- from ..base import EmbeddingProvider, EmbeddingError, EmbeddingResponse, EmbeddingObject
10
+ from ..models import Model
11
+ from ..utils import image_to_base64
12
+ from ..base import EmbeddingProvider, EmbeddingError, EmbeddingResponse
13
+
14
+ logger = logging.getLogger(__name__)
11
15
 
12
16
 
13
17
  class CohereInputType(Enum):
@@ -23,18 +27,20 @@ class CohereProvider(EmbeddingProvider):
23
27
  def __init__(
24
28
  self,
25
29
  api_key: str,
26
- model_name: str,
30
+ model: Model.Cohere,
27
31
  text_batch_size: int,
28
32
  image_batch_size: int,
29
33
  text_input_type: CohereInputType = CohereInputType.SEARCH_DOCUMENT,
30
34
  ):
35
+ super().__init__(
36
+ model_name=model.value,
37
+ text_batch_size=text_batch_size,
38
+ image_batch_size=image_batch_size,
39
+ provider_name="Cohere",
40
+ )
31
41
  self.api_key = api_key
32
- self.model_name = model_name
33
- self.text_batch_size = text_batch_size
34
- self.image_batch_size = image_batch_size
35
42
  self.input_type = text_input_type
36
43
  self._client = None
37
- self.provider_name = "Cohere"
38
44
 
39
45
  def _get_client(self):
40
46
  """Lazy load the Cohere client."""
@@ -54,9 +60,7 @@ class CohereProvider(EmbeddingProvider):
54
60
  def embed_text(self, texts: Union[str, List[str]], **kwargs) -> EmbeddingResponse:
55
61
  """Generate text embeddings using the Cohere API."""
56
62
  client = self._get_client()
57
-
58
- if isinstance(texts, str):
59
- texts = [texts]
63
+ texts = self._normalize_text_input(texts)
60
64
 
61
65
  try:
62
66
  all_embeddings = []
@@ -72,16 +76,7 @@ class CohereProvider(EmbeddingProvider):
72
76
  )
73
77
  all_embeddings.extend(np.array(response.embeddings.float_))
74
78
 
75
- return EmbeddingResponse(
76
- model_name=self.model_name,
77
- model_provider=self.provider_name,
78
- input_type=self.input_type.value,
79
- objects=[
80
- EmbeddingObject(
81
- embedding=e,
82
- ) for e in all_embeddings
83
- ]
84
- )
79
+ return self._create_text_response(all_embeddings, self.input_type.value)
85
80
 
86
81
  except Exception as e:
87
82
  raise EmbeddingError(f"Failed to embed text with Cohere: {e}") from e
@@ -92,12 +87,9 @@ class CohereProvider(EmbeddingProvider):
92
87
  ) -> EmbeddingResponse:
93
88
  """Generate embeddings for images using Cohere API."""
94
89
  client = self._get_client()
95
- input_type = "image"
96
-
97
- if isinstance(images, (str, Path)):
98
- images = [Path(images)]
99
- else:
100
- images = [Path(img) for img in images]
90
+ images = self._normalize_image_input(images)
91
+ total_images = len(images)
92
+ logger.info(f"Starting to process {total_images} images")
101
93
 
102
94
  try:
103
95
  all_embeddings = []
@@ -106,12 +98,17 @@ class CohereProvider(EmbeddingProvider):
106
98
  # Process images in batches
107
99
  for i in range(0, len(images), self.image_batch_size):
108
100
  batch_images = images[i : i + self.image_batch_size]
101
+ logger.info(f"Processing batch {i//self.image_batch_size + 1} of {(total_images + self.image_batch_size - 1)//self.image_batch_size} ({len(batch_images)} images)")
109
102
  b64_images = []
110
103
 
111
104
  for image in batch_images:
112
105
  if not image.exists():
113
106
  raise EmbeddingError(f"Image not found: {image}")
114
- b64_images.append(image_to_base64(image))
107
+ b64_data, content_type = image_to_base64(image)
108
+ # Construct full data URI for API
109
+ data_uri = f"data:{content_type};base64,{b64_data}"
110
+ b64_images.append(data_uri)
111
+ all_b64_images.append((b64_data, content_type))
115
112
 
116
113
  response = client.embed(
117
114
  model=self.model_name,
@@ -121,24 +118,14 @@ class CohereProvider(EmbeddingProvider):
121
118
  )
122
119
 
123
120
  all_embeddings.extend(np.array(response.embeddings.float_))
124
- all_b64_images.extend(b64_images)
125
-
126
- return EmbeddingResponse(
127
- model_name=self.model_name,
128
- model_provider=self.provider_name,
129
- input_type=input_type,
130
- objects=[
131
- EmbeddingObject(
132
- embedding=all_embeddings[i],
133
- source_b64=all_b64_images[i]
134
- ) for i in range(len(all_embeddings))
135
- ]
121
+
122
+ logger.info(f"Successfully processed all {total_images} images")
123
+ return self._create_image_response(
124
+ all_embeddings,
125
+ [b64 for b64, _ in all_b64_images],
126
+ [content_type for _, content_type in all_b64_images],
136
127
  )
137
128
 
138
129
  except Exception as e:
130
+ logger.error(f"Failed to embed images: {e}")
139
131
  raise EmbeddingError(f"Failed to embed image with Cohere: {e}") from e
140
-
141
- def embed_pdf(self, pdf_path: Path) -> EmbeddingResponse:
142
- """Generate embeddings for a PDF file using Cohere API."""
143
- image_paths = pdf_to_images(pdf_path)
144
- return self.embed_image(image_paths)
@@ -0,0 +1,162 @@
1
+ # ./src/embedkit/providers/colpali.py
2
+ """ColPali embedding provider."""
3
+
4
+ from typing import Union, List, Optional
5
+ from pathlib import Path
6
+ import logging
7
+ import numpy as np
8
+ import torch
9
+ from PIL import Image
10
+
11
+ from ..models import Model
12
+ from ..utils import image_to_base64
13
+ from ..base import EmbeddingProvider, EmbeddingError, EmbeddingResponse
14
+
15
+ logger = logging.getLogger(__name__)
16
+
17
+
18
+ class ColPaliProvider(EmbeddingProvider):
19
+ """ColPali embedding provider for document understanding."""
20
+
21
+ def __init__(
22
+ self,
23
+ model: Model.ColPali,
24
+ text_batch_size: int,
25
+ image_batch_size: int,
26
+ device: Optional[str] = None,
27
+ ):
28
+ super().__init__(
29
+ model_name=model.value,
30
+ text_batch_size=text_batch_size,
31
+ image_batch_size=image_batch_size,
32
+ provider_name="ColPali",
33
+ )
34
+
35
+ # Auto-detect device
36
+ if device is None:
37
+ if torch.cuda.is_available():
38
+ device = "cuda"
39
+ elif torch.backends.mps.is_available():
40
+ device = "mps"
41
+ else:
42
+ device = "cpu"
43
+
44
+ self._hf_device = device
45
+ self._hf_model = None
46
+ self._hf_processor = None
47
+
48
+ def _load_model(self):
49
+ """Lazy load the model."""
50
+ if self._hf_model is None:
51
+ try:
52
+ if self.model_name in [Model.ColPali.COLPALI_V1_3.value]:
53
+ from colpali_engine.models import ColPali, ColPaliProcessor
54
+
55
+ self._hf_model = ColPali.from_pretrained(
56
+ self.model_name,
57
+ torch_dtype=torch.bfloat16,
58
+ device_map=self._hf_device,
59
+ ).eval()
60
+
61
+ self._hf_processor = ColPaliProcessor.from_pretrained(self.model_name)
62
+
63
+ elif self.model_name in [
64
+ Model.ColPali.COLSMOL_500M.value,
65
+ Model.ColPali.COLSMOL_256M.value,
66
+ ]:
67
+ from colpali_engine.models import ColIdefics3, ColIdefics3Processor
68
+
69
+ self._hf_model = ColIdefics3.from_pretrained(
70
+ self.model_name,
71
+ torch_dtype=torch.bfloat16,
72
+ device_map=self._hf_device,
73
+ ).eval()
74
+ self._hf_processor = ColIdefics3Processor.from_pretrained(
75
+ self.model_name
76
+ )
77
+ else:
78
+ raise ValueError(f"Unable to load model for: {self.model_name}.")
79
+
80
+ logger.info(f"Loaded {self.model_name} on {self._hf_device}")
81
+
82
+ except ImportError as e:
83
+ raise EmbeddingError(
84
+ "ColPali not installed. Run: pip install colpali-engine"
85
+ ) from e
86
+ except Exception as e:
87
+ raise EmbeddingError(f"Failed to load model: {e}") from e
88
+
89
+ def embed_text(self, texts: Union[str, List[str]], **kwargs) -> EmbeddingResponse:
90
+ """Generate embeddings for text inputs."""
91
+ self._load_model()
92
+ texts = self._normalize_text_input(texts)
93
+
94
+ try:
95
+ # Process texts in batches
96
+ all_embeddings: List[np.ndarray] = []
97
+
98
+ for i in range(0, len(texts), self.text_batch_size):
99
+ batch_texts = texts[i : i + self.text_batch_size]
100
+ processed = self._hf_processor.process_queries(batch_texts).to(self._hf_device)
101
+
102
+ with torch.no_grad():
103
+ batch_embeddings = self._hf_model(**processed)
104
+ all_embeddings.append(batch_embeddings.cpu().float().numpy())
105
+
106
+ # Concatenate all batch embeddings
107
+ final_embeddings = np.concatenate(all_embeddings, axis=0)
108
+ return self._create_text_response(final_embeddings)
109
+
110
+ except Exception as e:
111
+ raise EmbeddingError(f"Failed to embed text: {e}") from e
112
+
113
+ def embed_image(
114
+ self, images: Union[Path, str, List[Union[Path, str]]]
115
+ ) -> EmbeddingResponse:
116
+ """Generate embeddings for images."""
117
+ self._load_model()
118
+ images = self._normalize_image_input(images)
119
+ total_images = len(images)
120
+ logger.info(f"Starting to process {total_images} images")
121
+
122
+ try:
123
+ # Process images in batches
124
+ all_embeddings: List[np.ndarray] = []
125
+ all_b64_data: List[str] = []
126
+ all_content_types: List[str] = []
127
+
128
+ for i in range(0, len(images), self.image_batch_size):
129
+ batch_images = images[i : i + self.image_batch_size]
130
+ logger.info(f"Processing batch {i//self.image_batch_size + 1} of {(total_images + self.image_batch_size - 1)//self.image_batch_size} ({len(batch_images)} images)")
131
+ pil_images = []
132
+ batch_b64_data = []
133
+ batch_content_types = []
134
+
135
+ for img_path in batch_images:
136
+ if not img_path.exists():
137
+ raise EmbeddingError(f"Image not found: {img_path}")
138
+
139
+ with Image.open(img_path) as img:
140
+ pil_images.append(img.convert("RGB"))
141
+ b64, content_type = image_to_base64(img_path)
142
+ batch_b64_data.append(b64)
143
+ batch_content_types.append(content_type)
144
+
145
+ processed = self._hf_processor.process_images(pil_images).to(self._hf_device)
146
+
147
+ with torch.no_grad():
148
+ batch_embeddings = self._hf_model(**processed)
149
+ all_embeddings.append(batch_embeddings.cpu().float().numpy())
150
+ all_b64_data.extend(batch_b64_data)
151
+ all_content_types.extend(batch_content_types)
152
+
153
+ # Concatenate all batch embeddings
154
+ final_embeddings = np.concatenate(all_embeddings, axis=0)
155
+ logger.info(f"Successfully processed all {total_images} images")
156
+ return self._create_image_response(
157
+ final_embeddings, all_b64_data, all_content_types
158
+ )
159
+
160
+ except Exception as e:
161
+ logger.error(f"Failed to embed images: {e}")
162
+ raise EmbeddingError(f"Failed to embed images: {e}") from e
@@ -0,0 +1,142 @@
1
+ import tempfile
2
+ import shutil
3
+ import logging
4
+ from contextlib import contextmanager
5
+ from pdf2image import convert_from_path
6
+ from pathlib import Path
7
+ from .config import get_temp_dir
8
+ from typing import Union, List, Iterator, Callable, TypeVar, Any
9
+
10
+ logger = logging.getLogger(__name__)
11
+
12
+ T = TypeVar("T")
13
+
14
+
15
+ @contextmanager
16
+ def temporary_directory() -> Iterator[Path]:
17
+ """Create a temporary directory that is automatically cleaned up when done.
18
+
19
+ Yields:
20
+ Path: Path to the temporary directory
21
+ """
22
+ temp_dir = Path(tempfile.mkdtemp())
23
+ try:
24
+ yield temp_dir
25
+ finally:
26
+ shutil.rmtree(temp_dir)
27
+
28
+
29
+ def pdf_to_images(pdf_path: Path) -> List[Path]:
30
+ """Convert a PDF file to a list of images.
31
+
32
+ Each page is rendered to its own temporary PNG file. The working directory
33
+ used during conversion is removed before returning, but the PNG files are not.
34
+
35
+ Args:
36
+ pdf_path: Path to the PDF file
37
+
38
+ Returns:
39
+ List[Path]: List of paths to the generated images
40
+
41
+ Note:
42
+ The returned files are not deleted automatically by this function.
43
+ The caller is responsible for removing them once they are no longer needed.
44
+ """
45
+ with temporary_directory() as temp_dir:
46
+ images = convert_from_path(pdf_path=str(pdf_path), output_folder=str(temp_dir))
47
+ image_paths = []
48
+
49
+ for i, image in enumerate(images):
50
+ output_path = temp_dir / f"{pdf_path.stem}_{i}.png"
51
+ image.save(output_path)
52
+ final_path = Path(tempfile.mktemp(suffix=".png"))
53
+ shutil.move(output_path, final_path)
54
+ image_paths.append(final_path)
55
+
56
+ return image_paths
57
+
58
+
59
+ def image_to_base64(image_path: Union[str, Path]) -> tuple[str, str]:
60
+ """Convert an image to base64 and return the base64 data and content type.
61
+
62
+ Args:
63
+ image_path: Path to the image file
64
+
65
+ Returns:
66
+ tuple[str, str]: (base64_data, content_type)
67
+
68
+ Raises:
69
+ ValueError: If the image cannot be read or has an unsupported format
70
+ """
71
+ import base64
72
+
73
+ try:
74
+ base64_data = base64.b64encode(Path(image_path).read_bytes()).decode("utf-8")
75
+ except Exception as e:
76
+ raise ValueError(f"Failed to read image {image_path}: {e}") from e
77
+
78
+ if isinstance(image_path, Path):
79
+ image_path_str = str(image_path)
80
+ else:
81
+ image_path_str = image_path
82
+
83
+ if image_path_str.lower().endswith(".png"):
84
+ content_type = "image/png"
85
+ elif image_path_str.lower().endswith((".jpg", ".jpeg")):
86
+ content_type = "image/jpeg"
87
+ elif image_path_str.lower().endswith(".gif"):
88
+ content_type = "image/gif"
89
+ else:
90
+ raise ValueError(
91
+ f"Unsupported image format for {image_path}; expected .png, .jpg, .jpeg, or .gif"
92
+ )
93
+
94
+ return base64_data, content_type
95
+
96
+
97
+ def with_pdf_cleanup(embed_func: Callable[..., T]) -> Callable[..., T]:
98
+ """Decorator to handle PDF to image conversion with automatic cleanup.
99
+
100
+ This decorator handles the common pattern of:
101
+ 1. Converting PDF to images
102
+ 2. Passing images to an embedding function
103
+ 3. Cleaning up temporary files
104
+
105
+ Args:
106
+ embed_func: Function that takes a list of image paths and returns embeddings
107
+
108
+ Returns:
109
+ Callable that takes a PDF path and returns embeddings
110
+ """
111
+
112
+ def wrapper(*args, **kwargs) -> T:
113
+ # The PDF path is the last positional argument (for instance methods, self comes first)
114
+ pdf_path = args[-1] if args else kwargs.get("pdf_path")
115
+ if not pdf_path:
116
+ raise ValueError(
117
+ "PDF path must be provided as the last positional argument or as 'pdf_path' keyword argument"
118
+ )
119
+
120
+ images = [] # Initialize images as empty list
121
+ try:
122
+ images = pdf_to_images(pdf_path)
123
+ # Call the original function with the images instead of pdf_path
124
+ if args:
125
+ # For instance methods, replace the last argument (pdf_path) with images
126
+ args = list(args)
127
+ args[-1] = images
128
+ else:
129
+ kwargs["pdf_path"] = images
130
+ return embed_func(*args, **kwargs)
131
+ finally:
132
+ # Clean up temporary files created by pdf_to_images
133
+ for img_path in images:
134
+ try:
135
+ if img_path.exists() and str(img_path).startswith(
136
+ tempfile.gettempdir()
137
+ ):
138
+ img_path.unlink()
139
+ except Exception as e:
140
+ logger.warning(f"Failed to clean up temporary file {img_path}: {e}")
141
+
142
+ return wrapper
@@ -114,7 +114,7 @@ def test_cohere_missing_api_key():
114
114
  # ===============================
115
115
  def test_colpali_text_embedding():
116
116
  """Test text embedding with Colpali model."""
117
- kit = EmbedKit.colpali(model=Model.ColPali.V1_3)
117
+ kit = EmbedKit.colpali(model=Model.ColPali.COLPALI_V1_3)
118
118
  result = kit.embed_text("Hello world")
119
119
 
120
120
  assert len(result.objects) == 1
@@ -133,7 +133,7 @@ def test_colpali_text_embedding():
133
133
  )
134
134
  def test_colpali_file_embedding(request, embed_method, file_fixture):
135
135
  """Test file embedding with Colpali model."""
136
- kit = EmbedKit.colpali(model=Model.ColPali.V1_3)
136
+ kit = EmbedKit.colpali(model=Model.ColPali.COLPALI_V1_3)
137
137
  file_path = request.getfixturevalue(file_fixture)
138
138
  embed_func = getattr(kit, embed_method)
139
139
  result = embed_func(file_path)
@@ -0,0 +1,52 @@
1
+ import pytest
2
+ from pathlib import Path
3
+ import tempfile
4
+ from embedkit.utils import temporary_directory, pdf_to_images
5
+
6
+
7
@pytest.fixture
def sample_pdf_path():
    """Provide the single-page sample PDF used by the utility tests.

    The path is resolved relative to this test module rather than the current
    working directory, so the fixture is found no matter where pytest is
    invoked from. The test is skipped when the file is missing.
    """
    path = Path(__file__).resolve().parent / "fixtures" / "2407.01449v6_p1.pdf"
    if not path.exists():
        pytest.skip(f"Test fixture not found: {path}")
    return path
14
+
15
+
16
def test_temporary_directory_cleanup():
    """Verify that ``temporary_directory`` removes the directory and its contents on exit."""
    with temporary_directory() as workdir:
        # Drop a file into the directory so cleanup has something to remove.
        marker = workdir / "test.txt"
        marker.write_text("test content")
        assert marker.exists()
        assert workdir.exists()

    # Leaving the context must delete the directory together with the file.
    assert not workdir.exists()
    assert not marker.exists()
30
+
31
+
32
def test_pdf_to_images_temporary_files(sample_pdf_path):
    """Test that PDF to images conversion writes readable PNGs into the temp directory.

    The generated files are removed in a ``finally`` block so that a failing
    assertion does not leave images behind in the temporary directory.
    """
    image_paths = pdf_to_images(sample_pdf_path)

    try:
        # Check that we got image paths
        assert len(image_paths) > 0

        # Verify all images exist and are in temp directory
        temp_root = tempfile.gettempdir()
        for img_path in image_paths:
            assert img_path.exists()
            assert str(img_path).startswith(temp_root)
            assert img_path.suffix == ".png"

            # A non-empty file is the cheapest readability check available here
            assert img_path.stat().st_size > 0
    finally:
        # Clean up the temporary files even when an assertion above failed
        for img_path in image_paths:
            img_path.unlink(missing_ok=True)

    assert not any(img_path.exists() for img_path in image_paths)
@@ -349,7 +349,7 @@ wheels = [
349
349
 
350
350
  [[package]]
351
351
  name = "embedkit"
352
- version = "0.1.0"
352
+ version = "0.1.5"
353
353
  source = { editable = "." }
354
354
  dependencies = [
355
355
  { name = "accelerate" },
@@ -377,7 +377,7 @@ requires-dist = [
377
377
  { name = "pdf2image", specifier = ">=1.17.0" },
378
378
  { name = "pillow", specifier = ">=11.2.1" },
379
379
  { name = "torch", specifier = "<=2.5" },
380
- { name = "transformers" },
380
+ { name = "transformers", specifier = ">=4.46.2" },
381
381
  ]
382
382
 
383
383
  [package.metadata.requires-dev]
@@ -1,53 +0,0 @@
1
- # ./src/embedkit/base.py
2
- """Base classes for EmbedKit."""
3
-
4
- from abc import ABC, abstractmethod
5
- from typing import Union, List, Optional
6
- from pathlib import Path
7
- import numpy as np
8
- from dataclasses import dataclass
9
-
10
-
11
@dataclass
class EmbeddingObject:
    """A single embedding plus an optional encoded copy of the input it came from."""

    # Embedding array produced for one input.
    embedding: np.ndarray
    # Encoded source of the input (the ColPali provider stores the image's
    # base64 data URI here); None when no source is attached.
    source_b64: Optional[str] = None
15
-
16
-
17
@dataclass
class EmbeddingResponse:
    """Result of one embedding call: model metadata plus the embedded objects."""

    # Name of the model that produced the embeddings.
    model_name: str
    # Name of the provider the model belongs to.
    model_provider: str
    # Kind of input that was embedded (providers pass e.g. "text" or "image").
    input_type: str
    # One entry per embedded input.
    objects: List[EmbeddingObject]

    @property
    def shape(self) -> tuple:
        """Return the shape of the first object's embedding."""
        first = self.objects[0]
        return first.embedding.shape
27
-
28
-
29
class EmbeddingProvider(ABC):
    """Interface every embedding provider has to implement."""

    @abstractmethod
    def embed_text(self, texts: Union[str, List[str]], **kwargs) -> EmbeddingResponse:
        """Embed one text or a list of texts and return the provider's response."""

    @abstractmethod
    def embed_image(
        self, images: Union[Path, str, List[Union[Path, str]]]
    ) -> EmbeddingResponse:
        """Embed one image path or a list of image paths."""

    @abstractmethod
    def embed_pdf(self, pdf: Union[Path, str]) -> EmbeddingResponse:
        """Embed the contents of a single PDF file."""
48
-
49
-
50
class EmbeddingError(Exception):
    """Root exception type for failures raised while producing embeddings."""
@@ -1,12 +0,0 @@
1
- # ./src/embedkit/models.py
2
- """Model definitions and enum for EmbedKit."""
3
-
4
- from enum import Enum
5
-
6
-
7
class Model:
    """Namespace grouping the supported model identifiers by provider."""

    class ColPali(Enum):
        """Identifiers of the available ColPali models."""

        V1_3 = "colpali-v1.3"

    class Cohere(Enum):
        """Identifiers of the available Cohere models."""

        EMBED_V4_0 = "embed-v4.0"
@@ -1,160 +0,0 @@
1
- # ./src/embedkit/providers/colpali.py
2
- """ColPali embedding provider."""
3
-
4
- from typing import Union, List, Optional
5
- from pathlib import Path
6
- import logging
7
- import numpy as np
8
- import torch
9
- from PIL import Image
10
-
11
- from ..utils import pdf_to_images, image_to_base64
12
- from ..base import EmbeddingProvider, EmbeddingError, EmbeddingResponse, EmbeddingObject
13
-
14
- logger = logging.getLogger(__name__)
15
-
16
-
17
class ColPaliProvider(EmbeddingProvider):
    """ColPali embedding provider for document understanding.

    The model and its processor are loaded lazily on the first embedding call.
    """

    def __init__(
        self,
        model_name: str,
        text_batch_size: int,
        image_batch_size: int,
        device: Optional[str] = None,
    ):
        """Store the configuration; no model weights are loaded here.

        Args:
            model_name: Identifier passed to ``from_pretrained`` for both the
                model and the processor.
            text_batch_size: Number of texts processed per forward pass.
            image_batch_size: Number of images processed per forward pass.
            device: Target device; when None, "cuda" is preferred, then "mps",
                then "cpu".
        """
        self.model_name = model_name
        self.provider_name = "ColPali"
        self.text_batch_size = text_batch_size
        self.image_batch_size = image_batch_size

        # Auto-detect device
        if device is None:
            if torch.cuda.is_available():
                device = "cuda"
            elif torch.backends.mps.is_available():
                device = "mps"
            else:
                device = "cpu"

        self.device = device
        self._model = None
        self._processor = None

    def _load_model(self):
        """Load the model and processor on first use.

        Raises:
            EmbeddingError: If ``colpali_engine`` is not installed or loading fails.
        """
        if self._model is not None and self._processor is not None:
            return

        try:
            from colpali_engine.models import ColPali, ColPaliProcessor

            model = ColPali.from_pretrained(
                self.model_name,
                torch_dtype=torch.bfloat16,
                device_map=self.device,
            ).eval()
            processor = ColPaliProcessor.from_pretrained(self.model_name)
        except ImportError as e:
            raise EmbeddingError(
                "ColPali not installed. Run: pip install colpali-engine"
            ) from e
        except Exception as e:
            raise EmbeddingError(f"Failed to load model: {e}") from e

        # Publish both attributes only after both loads succeeded; otherwise a
        # processor failure would leave a model without a processor and later
        # calls would skip loading and crash on ``None``.
        self._model = model
        self._processor = processor
        logger.info(f"Loaded ColPali model on {self.device}")

    def embed_text(self, texts: Union[str, List[str]]) -> EmbeddingResponse:
        """Generate embeddings for text inputs.

        Args:
            texts: A single string or a list of strings.

        Returns:
            Response with one ``EmbeddingObject`` per entry of the concatenated
            batch outputs.

        Raises:
            EmbeddingError: If the model cannot be loaded or embedding fails.
        """
        self._load_model()

        if isinstance(texts, str):
            texts = [texts]

        try:
            # Process texts in batches
            all_embeddings = []

            for i in range(0, len(texts), self.text_batch_size):
                batch_texts = texts[i : i + self.text_batch_size]
                processed = self._processor.process_queries(batch_texts).to(self.device)

                with torch.no_grad():
                    batch_embeddings = self._model(**processed)
                    # Cast to float32 on the CPU before converting to NumPy.
                    all_embeddings.append(batch_embeddings.cpu().float().numpy())

            # Concatenate all batch embeddings
            final_embeddings = np.concatenate(all_embeddings, axis=0)

            return EmbeddingResponse(
                model_name=self.model_name,
                model_provider=self.provider_name,
                input_type="text",
                objects=[EmbeddingObject(embedding=e) for e in final_embeddings],
            )

        except Exception as e:
            raise EmbeddingError(f"Failed to embed text: {e}") from e

    def embed_image(
        self, images: Union[Path, str, List[Union[Path, str]]]
    ) -> EmbeddingResponse:
        """Generate embeddings for images.

        Args:
            images: A single image path or a list of image paths.

        Returns:
            Response with one ``EmbeddingObject`` per image, each carrying the
            image's base64 data URI in ``source_b64``.

        Raises:
            EmbeddingError: If an image is missing, the model cannot be loaded,
                or embedding fails.
        """
        self._load_model()

        if isinstance(images, (str, Path)):
            images = [Path(images)]
        else:
            images = [Path(img) for img in images]

        try:
            # Process images in batches
            all_embeddings = []
            all_b64_images = []

            for i in range(0, len(images), self.image_batch_size):
                batch_images = images[i : i + self.image_batch_size]
                pil_images = []
                b64_images = []

                for img_path in batch_images:
                    if not img_path.exists():
                        raise EmbeddingError(f"Image not found: {img_path}")

                    with Image.open(img_path) as img:
                        pil_images.append(img.convert("RGB"))
                    b64_images.append(image_to_base64(img_path))

                processed = self._processor.process_images(pil_images).to(self.device)

                with torch.no_grad():
                    batch_embeddings = self._model(**processed)
                    all_embeddings.append(batch_embeddings.cpu().float().numpy())
                    all_b64_images.extend(b64_images)

            # Concatenate all batch embeddings
            final_embeddings = np.concatenate(all_embeddings, axis=0)

            return EmbeddingResponse(
                model_name=self.model_name,
                model_provider=self.provider_name,
                input_type="image",
                objects=[
                    EmbeddingObject(embedding=embedding, source_b64=b64)
                    for embedding, b64 in zip(final_embeddings, all_b64_images)
                ],
            )

        except EmbeddingError:
            # Already carries a precise message (e.g. "Image not found");
            # do not wrap it a second time.
            raise
        except Exception as e:
            raise EmbeddingError(f"Failed to embed images: {e}") from e

    def embed_pdf(self, pdf_path: Union[Path, str]) -> EmbeddingResponse:
        """Generate embeddings for a PDF file by embedding its page images.

        Args:
            pdf_path: Location of the PDF; strings are converted to ``Path``.

        Returns:
            The response produced by ``embed_image`` for the converted pages.
        """
        images = pdf_to_images(Path(pdf_path))
        try:
            return self.embed_image(images)
        finally:
            # The page images are intermediate artefacts; their content is
            # already carried in ``source_b64``, so remove them either way.
            for img_path in images:
                try:
                    img_path.unlink(missing_ok=True)
                except OSError as e:
                    logger.warning(f"Failed to remove temporary image {img_path}: {e}")
@@ -1,48 +0,0 @@
1
- from pdf2image import convert_from_path
2
- from pathlib import Path
3
- from .config import get_temp_dir
4
- from typing import Union
5
-
6
-
7
def pdf_to_images(pdf_path: Union[Path, str]) -> list[Path]:
    """Convert a PDF file to a list of PNG images.

    Each image returned by ``convert_from_path`` is saved as
    ``<pdf stem>_<index>.png`` inside the ``images`` sub-directory of the
    configured temp directory.

    Args:
        pdf_path: Location of the PDF. Strings are accepted and converted to
            ``Path`` so that ``.stem`` is available.

    Returns:
        Paths of the written PNG files, in the order the images were returned.
    """
    pdf_path = Path(pdf_path)
    img_temp_dir = get_temp_dir() / "images"
    img_temp_dir.mkdir(parents=True, exist_ok=True)
    pages = convert_from_path(pdf_path=str(pdf_path), output_folder=str(img_temp_dir))

    image_paths = []
    for i, page in enumerate(pages):
        output_path = img_temp_dir / f"{pdf_path.stem}_{i}.png"
        # Drop a stale file left over from an earlier conversion of a PDF with
        # the same stem before writing the new one.
        if output_path.exists():
            output_path.unlink()

        page.save(output_path)
        image_paths.append(output_path)
    return image_paths
23
-
24
-
25
def image_to_base64(image_path: Union[str, Path]):
    """Read an image file and return it as a base64 ``data:`` URI.

    Args:
        image_path: Location of a .png, .jpg, .jpeg or .gif file (the
            extension check is case-insensitive).

    Returns:
        A string of the form ``data:<content type>;base64,<payload>``.

    Raises:
        ValueError: If the file cannot be read or its extension is not one of
            the supported formats.
    """
    import base64

    try:
        base64_only = base64.b64encode(Path(image_path).read_bytes()).decode("utf-8")
    except Exception as e:
        raise ValueError(f"Failed to read image {image_path}: {e}") from e

    # Convert unconditionally: the name used to be bound only for ``Path``
    # inputs, so plain string paths failed with an unbound-variable error.
    lowered = str(image_path).lower()

    if lowered.endswith(".png"):
        content_type = "image/png"
    elif lowered.endswith((".jpg", ".jpeg")):
        content_type = "image/jpeg"
    elif lowered.endswith(".gif"):
        content_type = "image/gif"
    else:
        raise ValueError(
            f"Unsupported image format for {image_path}; expected .png, .jpg, .jpeg, or .gif"
        )

    return f"data:{content_type};base64,{base64_only}"
File without changes
File without changes
File without changes
File without changes