kiln-ai 0.20.1__py3-none-any.whl → 0.22.0__py3-none-any.whl
This diff shows the changes between publicly released package versions as they appear in their public registries. It is provided for informational purposes only.
Potentially problematic release: this version of kiln-ai might be problematic.
- kiln_ai/adapters/__init__.py +6 -0
- kiln_ai/adapters/adapter_registry.py +43 -226
- kiln_ai/adapters/chunkers/__init__.py +13 -0
- kiln_ai/adapters/chunkers/base_chunker.py +42 -0
- kiln_ai/adapters/chunkers/chunker_registry.py +16 -0
- kiln_ai/adapters/chunkers/fixed_window_chunker.py +39 -0
- kiln_ai/adapters/chunkers/helpers.py +23 -0
- kiln_ai/adapters/chunkers/test_base_chunker.py +63 -0
- kiln_ai/adapters/chunkers/test_chunker_registry.py +28 -0
- kiln_ai/adapters/chunkers/test_fixed_window_chunker.py +346 -0
- kiln_ai/adapters/chunkers/test_helpers.py +75 -0
- kiln_ai/adapters/data_gen/test_data_gen_task.py +9 -3
- kiln_ai/adapters/embedding/__init__.py +0 -0
- kiln_ai/adapters/embedding/base_embedding_adapter.py +44 -0
- kiln_ai/adapters/embedding/embedding_registry.py +32 -0
- kiln_ai/adapters/embedding/litellm_embedding_adapter.py +199 -0
- kiln_ai/adapters/embedding/test_base_embedding_adapter.py +283 -0
- kiln_ai/adapters/embedding/test_embedding_registry.py +166 -0
- kiln_ai/adapters/embedding/test_litellm_embedding_adapter.py +1149 -0
- kiln_ai/adapters/eval/eval_runner.py +6 -2
- kiln_ai/adapters/eval/test_base_eval.py +1 -3
- kiln_ai/adapters/eval/test_g_eval.py +1 -1
- kiln_ai/adapters/extractors/__init__.py +18 -0
- kiln_ai/adapters/extractors/base_extractor.py +72 -0
- kiln_ai/adapters/extractors/encoding.py +20 -0
- kiln_ai/adapters/extractors/extractor_registry.py +44 -0
- kiln_ai/adapters/extractors/extractor_runner.py +112 -0
- kiln_ai/adapters/extractors/litellm_extractor.py +406 -0
- kiln_ai/adapters/extractors/test_base_extractor.py +244 -0
- kiln_ai/adapters/extractors/test_encoding.py +54 -0
- kiln_ai/adapters/extractors/test_extractor_registry.py +181 -0
- kiln_ai/adapters/extractors/test_extractor_runner.py +181 -0
- kiln_ai/adapters/extractors/test_litellm_extractor.py +1290 -0
- kiln_ai/adapters/fine_tune/test_dataset_formatter.py +2 -2
- kiln_ai/adapters/fine_tune/test_fireworks_finetune.py +2 -6
- kiln_ai/adapters/fine_tune/test_together_finetune.py +2 -6
- kiln_ai/adapters/ml_embedding_model_list.py +494 -0
- kiln_ai/adapters/ml_model_list.py +876 -18
- kiln_ai/adapters/model_adapters/litellm_adapter.py +40 -75
- kiln_ai/adapters/model_adapters/test_litellm_adapter.py +79 -1
- kiln_ai/adapters/model_adapters/test_litellm_adapter_tools.py +119 -5
- kiln_ai/adapters/model_adapters/test_saving_adapter_results.py +9 -3
- kiln_ai/adapters/model_adapters/test_structured_output.py +9 -10
- kiln_ai/adapters/ollama_tools.py +69 -12
- kiln_ai/adapters/provider_tools.py +190 -46
- kiln_ai/adapters/rag/deduplication.py +49 -0
- kiln_ai/adapters/rag/progress.py +252 -0
- kiln_ai/adapters/rag/rag_runners.py +844 -0
- kiln_ai/adapters/rag/test_deduplication.py +195 -0
- kiln_ai/adapters/rag/test_progress.py +785 -0
- kiln_ai/adapters/rag/test_rag_runners.py +2376 -0
- kiln_ai/adapters/remote_config.py +80 -8
- kiln_ai/adapters/test_adapter_registry.py +579 -86
- kiln_ai/adapters/test_ml_embedding_model_list.py +239 -0
- kiln_ai/adapters/test_ml_model_list.py +202 -0
- kiln_ai/adapters/test_ollama_tools.py +340 -1
- kiln_ai/adapters/test_prompt_builders.py +1 -1
- kiln_ai/adapters/test_provider_tools.py +199 -8
- kiln_ai/adapters/test_remote_config.py +551 -56
- kiln_ai/adapters/vector_store/__init__.py +1 -0
- kiln_ai/adapters/vector_store/base_vector_store_adapter.py +83 -0
- kiln_ai/adapters/vector_store/lancedb_adapter.py +389 -0
- kiln_ai/adapters/vector_store/test_base_vector_store.py +160 -0
- kiln_ai/adapters/vector_store/test_lancedb_adapter.py +1841 -0
- kiln_ai/adapters/vector_store/test_vector_store_registry.py +199 -0
- kiln_ai/adapters/vector_store/vector_store_registry.py +33 -0
- kiln_ai/datamodel/__init__.py +16 -13
- kiln_ai/datamodel/basemodel.py +201 -4
- kiln_ai/datamodel/chunk.py +158 -0
- kiln_ai/datamodel/datamodel_enums.py +27 -0
- kiln_ai/datamodel/embedding.py +64 -0
- kiln_ai/datamodel/external_tool_server.py +206 -54
- kiln_ai/datamodel/extraction.py +317 -0
- kiln_ai/datamodel/project.py +33 -1
- kiln_ai/datamodel/rag.py +79 -0
- kiln_ai/datamodel/task.py +5 -0
- kiln_ai/datamodel/task_output.py +41 -11
- kiln_ai/datamodel/test_attachment.py +649 -0
- kiln_ai/datamodel/test_basemodel.py +270 -14
- kiln_ai/datamodel/test_chunk_models.py +317 -0
- kiln_ai/datamodel/test_dataset_split.py +1 -1
- kiln_ai/datamodel/test_datasource.py +50 -0
- kiln_ai/datamodel/test_embedding_models.py +448 -0
- kiln_ai/datamodel/test_eval_model.py +6 -6
- kiln_ai/datamodel/test_external_tool_server.py +534 -152
- kiln_ai/datamodel/test_extraction_chunk.py +206 -0
- kiln_ai/datamodel/test_extraction_model.py +501 -0
- kiln_ai/datamodel/test_rag.py +641 -0
- kiln_ai/datamodel/test_task.py +35 -1
- kiln_ai/datamodel/test_tool_id.py +187 -1
- kiln_ai/datamodel/test_vector_store.py +320 -0
- kiln_ai/datamodel/tool_id.py +58 -0
- kiln_ai/datamodel/vector_store.py +141 -0
- kiln_ai/tools/base_tool.py +12 -3
- kiln_ai/tools/built_in_tools/math_tools.py +12 -4
- kiln_ai/tools/kiln_task_tool.py +158 -0
- kiln_ai/tools/mcp_server_tool.py +2 -2
- kiln_ai/tools/mcp_session_manager.py +51 -22
- kiln_ai/tools/rag_tools.py +164 -0
- kiln_ai/tools/test_kiln_task_tool.py +527 -0
- kiln_ai/tools/test_mcp_server_tool.py +4 -15
- kiln_ai/tools/test_mcp_session_manager.py +187 -227
- kiln_ai/tools/test_rag_tools.py +929 -0
- kiln_ai/tools/test_tool_registry.py +290 -7
- kiln_ai/tools/tool_registry.py +69 -16
- kiln_ai/utils/__init__.py +3 -0
- kiln_ai/utils/async_job_runner.py +62 -17
- kiln_ai/utils/config.py +2 -2
- kiln_ai/utils/env.py +15 -0
- kiln_ai/utils/filesystem.py +14 -0
- kiln_ai/utils/filesystem_cache.py +60 -0
- kiln_ai/utils/litellm.py +94 -0
- kiln_ai/utils/lock.py +100 -0
- kiln_ai/utils/mime_type.py +38 -0
- kiln_ai/utils/open_ai_types.py +19 -2
- kiln_ai/utils/pdf_utils.py +59 -0
- kiln_ai/utils/test_async_job_runner.py +151 -35
- kiln_ai/utils/test_env.py +142 -0
- kiln_ai/utils/test_filesystem_cache.py +316 -0
- kiln_ai/utils/test_litellm.py +206 -0
- kiln_ai/utils/test_lock.py +185 -0
- kiln_ai/utils/test_mime_type.py +66 -0
- kiln_ai/utils/test_open_ai_types.py +88 -12
- kiln_ai/utils/test_pdf_utils.py +86 -0
- kiln_ai/utils/test_uuid.py +111 -0
- kiln_ai/utils/test_validation.py +524 -0
- kiln_ai/utils/uuid.py +9 -0
- kiln_ai/utils/validation.py +90 -0
- {kiln_ai-0.20.1.dist-info → kiln_ai-0.22.0.dist-info}/METADATA +9 -1
- kiln_ai-0.22.0.dist-info/RECORD +213 -0
- kiln_ai-0.20.1.dist-info/RECORD +0 -138
- {kiln_ai-0.20.1.dist-info → kiln_ai-0.22.0.dist-info}/WHEEL +0 -0
- {kiln_ai-0.20.1.dist-info → kiln_ai-0.22.0.dist-info}/licenses/LICENSE.txt +0 -0
kiln_ai/adapters/embedding/litellm_embedding_adapter.py
@@ -0,0 +1,199 @@
+from functools import cached_property
+from typing import Any, Dict, List, Tuple
+
+import litellm
+from litellm.types.utils import EmbeddingResponse
+from pydantic import BaseModel, Field
+
+from kiln_ai.adapters.embedding.base_embedding_adapter import (
+    BaseEmbeddingAdapter,
+    Embedding,
+    EmbeddingResult,
+)
+from kiln_ai.adapters.ml_embedding_model_list import (
+    KilnEmbeddingModelProvider,
+    built_in_embedding_models_from_provider,
+)
+from kiln_ai.adapters.provider_tools import LiteLlmCoreConfig
+from kiln_ai.datamodel.embedding import EmbeddingConfig
+from kiln_ai.utils.litellm import get_litellm_provider_info
+
+# litellm enforces a limit, documented here:
+# https://docs.litellm.ai/docs/embedding/supported_embedding
+# but some providers impose lower limits that LiteLLM does not know about
+# for example, Gemini currently has a limit of 100 inputs per request
+MAX_BATCH_SIZE = 100
+
+
+class EmbeddingOptions(BaseModel):
+    dimensions: int | None = Field(
+        default=None,
+        description="The number of dimensions to return for embeddings. Some models support requesting vectors of different dimensions.",
+    )
+
+
+def validate_map_to_embeddings(
+    response: EmbeddingResponse,
+    expected_embedding_count: int,
+) -> List[Embedding]:
+    # LiteLLM has an Embedding type in litellm.types.utils, but the EmbeddingResponse data has a list of untyped dicts,
+    # which can be dangerous especially if we upgrade litellm, so we do some sanity checks here
+    if not isinstance(response, EmbeddingResponse):
+        raise RuntimeError(f"Expected EmbeddingResponse, got {type(response)}.")
+
+    list_to_validate = response.data
+    if len(list_to_validate) != expected_embedding_count:
+        raise RuntimeError(
+            f"Expected the number of embeddings in the response to be {expected_embedding_count}, got {len(list_to_validate)}."
+        )
+
+    validated_vectors: List[Tuple[list[float], int]] = []
+    for embedding_dict in list_to_validate:
+        object_type = embedding_dict.get("object")
+        if object_type != "embedding":
+            raise RuntimeError(
+                f"Embedding response data has an unexpected shape. Property 'object' is not 'embedding'. Got {object_type}."
+            )
+
+        embedding_property_value = embedding_dict.get("embedding")
+        if embedding_property_value is None:
+            raise RuntimeError(
+                "Embedding response data has an unexpected shape. Property 'embedding' is None in response data item."
+            )
+        if not isinstance(embedding_property_value, list):
+            raise RuntimeError(
+                f"Embedding response data has an unexpected shape. Property 'embedding' is not a list. Got {type(embedding_property_value)}."
+            )
+
+        index_property_value = embedding_dict.get("index")
+        if index_property_value is None:
+            raise RuntimeError(
+                "Embedding response data has an unexpected shape. Property 'index' is None in response data item."
+            )
+        if not isinstance(index_property_value, int):
+            raise RuntimeError(
+                f"Embedding response data has an unexpected shape. Property 'index' is not an integer. Got {type(index_property_value)}."
+            )
+
+        validated_vectors.append((embedding_property_value, index_property_value))
+
+    # sort by index, in place - the data should already be sorted by index,
+    # but litellm docs are not explicit about this
+    validated_vectors.sort(key=lambda x: x[1])
+
+    return [
+        Embedding(vector=embedding_vector) for embedding_vector, _ in validated_vectors
+    ]
+
+
+class LitellmEmbeddingAdapter(BaseEmbeddingAdapter):
+    def __init__(
+        self, embedding_config: EmbeddingConfig, litellm_core_config: LiteLlmCoreConfig
+    ):
+        super().__init__(embedding_config)
+
+        self.litellm_core_config = litellm_core_config
+
+    async def _generate_embeddings(self, input_texts: List[str]) -> EmbeddingResult:
+        # batch the requests
+        batches: List[List[str]] = []
+        for i in range(0, len(input_texts), MAX_BATCH_SIZE):
+            batches.append(input_texts[i : i + MAX_BATCH_SIZE])
+
+        # generate embeddings for each batch
+        results: List[EmbeddingResult] = []
+        for batch in batches:
+            batch_response = await self._generate_embeddings_for_batch(batch)
+            results.append(batch_response)
+
+        # merge the results
+        combined_embeddings: List[Embedding] = []
+        combined_usage = None
+
+        # we prefer returning None overall usage if any of the results is missing usage
+        # better than returning a misleading usage
+        all_have_usage = all(result.usage is not None for result in results)
+        if all_have_usage:
+            combined_usage = litellm.Usage(
+                prompt_tokens=0, total_tokens=0, completion_tokens=0
+            )
+            for result in results:
+                if result.usage is not None:
+                    combined_usage.prompt_tokens += result.usage.prompt_tokens
+                    combined_usage.total_tokens += result.usage.total_tokens
+                    combined_usage.completion_tokens += result.usage.completion_tokens
+
+        for result in results:
+            combined_embeddings.extend(result.embeddings)
+
+        return EmbeddingResult(
+            embeddings=combined_embeddings,
+            usage=combined_usage,
+        )
+
+    async def _generate_embeddings_for_batch(
+        self, input_texts: List[str]
+    ) -> EmbeddingResult:
+        if len(input_texts) > MAX_BATCH_SIZE:
+            raise ValueError(
+                f"Too many input texts, max batch size is {MAX_BATCH_SIZE}, got {len(input_texts)}"
+            )
+
+        completion_kwargs: Dict[str, Any] = {}
+        if self.litellm_core_config.additional_body_options:
+            completion_kwargs.update(self.litellm_core_config.additional_body_options)
+
+        if self.litellm_core_config.base_url:
+            completion_kwargs["api_base"] = self.litellm_core_config.base_url
+
+        if self.litellm_core_config.default_headers:
+            completion_kwargs["default_headers"] = (
+                self.litellm_core_config.default_headers
+            )
+
+        response = await litellm.aembedding(
+            model=self.litellm_model_id,
+            input=input_texts,
+            **self.build_options().model_dump(exclude_none=True),
+            **completion_kwargs,
+        )
+
+        validated_embeddings = validate_map_to_embeddings(
+            response, expected_embedding_count=len(input_texts)
+        )
+
+        return EmbeddingResult(
+            embeddings=validated_embeddings,
+            usage=response.usage,
+        )
+
+    def build_options(self) -> EmbeddingOptions:
+        dimensions = self.embedding_config.properties.get("dimensions", None)
+        if dimensions is not None:
+            if not isinstance(dimensions, int) or dimensions <= 0:
+                raise ValueError("Dimensions must be a positive integer")
+
+        return EmbeddingOptions(
+            dimensions=dimensions,
+        )
+
+    @cached_property
+    def model_provider(self) -> KilnEmbeddingModelProvider:
+        provider = built_in_embedding_models_from_provider(
+            self.embedding_config.model_provider_name, self.embedding_config.model_name
+        )
+        if provider is None:
+            raise ValueError(
+                f"Embedding model {self.embedding_config.model_name} not found in the list of built-in models"
+            )
+        return provider

+    @cached_property
+    def litellm_model_id(self) -> str:
+        provider_info = get_litellm_provider_info(self.model_provider)
+        if provider_info.is_custom and self.litellm_core_config.base_url is None:
+            raise ValueError(
+                f"Provider {self.model_provider.name.value} must have an explicit base URL"
+            )
+
+        return provider_info.litellm_model_id
kiln_ai/adapters/embedding/test_base_embedding_adapter.py
@@ -0,0 +1,283 @@
+import pytest
+
+from kiln_ai.adapters.embedding.base_embedding_adapter import (
+    BaseEmbeddingAdapter,
+    Embedding,
+    EmbeddingResult,
+)
+from kiln_ai.datamodel.datamodel_enums import ModelProviderName
+from kiln_ai.datamodel.embedding import EmbeddingConfig
+
+
+class MockEmbeddingAdapter(BaseEmbeddingAdapter):
+    """Concrete implementation of BaseEmbeddingAdapter for testing purposes."""
+
+    async def _generate_embeddings(self, text_inputs: list[str]) -> EmbeddingResult:
+        # Simple test implementation that returns mock embeddings
+        embeddings = []
+        for i, _ in enumerate(text_inputs):
+            embeddings.append(Embedding(vector=[0.1 * (i + 1)] * 3))
+
+        return EmbeddingResult(embeddings=embeddings)
+
+
+class MockEmbeddingAdapterWithUsage(BaseEmbeddingAdapter):
+    """Concrete implementation that includes usage information."""
+
+    async def _generate_embeddings(self, text_inputs: list[str]) -> EmbeddingResult:
+        from litellm import Usage
+
+        embeddings = []
+        for i, _ in enumerate(text_inputs):
+            embeddings.append(Embedding(vector=[0.1 * (i + 1)] * 3))
+
+        usage = Usage(
+            prompt_tokens=len(text_inputs) * 10, total_tokens=len(text_inputs) * 10
+        )
+        return EmbeddingResult(embeddings=embeddings, usage=usage)
+
+
+@pytest.fixture
+def mock_embedding_config():
+    """Create a mock embedding config for testing."""
+    return EmbeddingConfig(
+        name="test-embedding",
+        model_provider_name=ModelProviderName.openai,
+        model_name="openai_text_embedding_3_small",
+        properties={},
+    )
+
+
+@pytest.fixture
+def test_adapter(mock_embedding_config):
+    """Create a test adapter instance."""
+    return MockEmbeddingAdapter(mock_embedding_config)
+
+
+@pytest.fixture
+def test_adapter_with_usage(mock_embedding_config):
+    """Create a test adapter instance that includes usage information."""
+    return MockEmbeddingAdapterWithUsage(mock_embedding_config)
+
+
+class TestEmbedding:
+    """Test the Embedding model."""
+
+    def test_creation(self):
+        """Test creating an Embedding with a vector."""
+        vector = [0.1, 0.2, 0.3]
+        embedding = Embedding(vector=vector)
+        assert embedding.vector == vector
+
+    def test_empty_vector(self):
+        """Test creating an Embedding with an empty vector."""
+        embedding = Embedding(vector=[])
+        assert embedding.vector == []
+
+    def test_large_vector(self):
+        """Test creating an Embedding with a large vector."""
+        vector = [0.1] * 1536
+        embedding = Embedding(vector=vector)
+        assert len(embedding.vector) == 1536
+        assert all(v == 0.1 for v in embedding.vector)
+
+
+class TestEmbeddingResult:
+    """Test the EmbeddingResult model."""
+
+    def test_creation_with_embeddings(self):
+        """Test creating an EmbeddingResult with embeddings."""
+        embeddings = [
+            Embedding(vector=[0.1, 0.2, 0.3]),
+            Embedding(vector=[0.4, 0.5, 0.6]),
+        ]
+        result = EmbeddingResult(embeddings=embeddings)
+        assert result.embeddings == embeddings
+        assert result.usage is None
+
+    def test_creation_with_usage(self):
+        """Test creating an EmbeddingResult with usage information."""
+        from litellm import Usage
+
+        embeddings = [Embedding(vector=[0.1, 0.2, 0.3])]
+        usage = Usage(prompt_tokens=10, total_tokens=10)
+        result = EmbeddingResult(embeddings=embeddings, usage=usage)
+        assert result.embeddings == embeddings
+        assert result.usage == usage
+
+    def test_empty_embeddings(self):
+        """Test creating an EmbeddingResult with empty embeddings."""
+        result = EmbeddingResult(embeddings=[])
+        assert result.embeddings == []
+        assert result.usage is None
+
+
+class TestBaseEmbeddingAdapter:
+    """Test the BaseEmbeddingAdapter abstract base class."""
+
+    def test_init(self, mock_embedding_config):
+        """Test successful initialization of the adapter."""
+        adapter = MockEmbeddingAdapter(mock_embedding_config)
+        assert adapter.embedding_config == mock_embedding_config
+        assert adapter.embedding_config.name == "test-embedding"
+        assert adapter.embedding_config.model_provider_name == ModelProviderName.openai
+        assert adapter.embedding_config.model_name == "openai_text_embedding_3_small"
+        assert adapter.embedding_config.properties == {}
+
+    def test_cannot_instantiate_abstract_class(self, mock_embedding_config):
+        """Test that BaseEmbeddingAdapter cannot be instantiated directly."""
+        with pytest.raises(TypeError):
+            BaseEmbeddingAdapter(mock_embedding_config)
+
+    async def test_generate_embeddings_empty_list(self, test_adapter):
+        """Test generate_embeddings with an empty text list."""
+        result = await test_adapter.generate_embeddings([])
+        assert result.embeddings == []
+        assert result.usage is None
+
+    async def test_generate_embeddings_single_text(self, test_adapter):
+        """Test generate_embeddings with a single text."""
+        result = await test_adapter.generate_embeddings(["hello world"])
+        assert len(result.embeddings) == 1
+        assert result.embeddings[0].vector == [0.1, 0.1, 0.1]
+        assert result.usage is None
+
+    async def test_generate_embeddings_multiple_texts(self, test_adapter):
+        """Test generate_embeddings with multiple texts."""
+        texts = ["hello world", "my name is john", "i like to eat apples"]
+        result = await test_adapter.generate_embeddings(texts)
+        assert len(result.embeddings) == 3
+        assert result.embeddings[0].vector == [0.1, 0.1, 0.1]
+        assert result.embeddings[1].vector == [0.2, 0.2, 0.2]
+        assert result.embeddings[2].vector == pytest.approx([0.3, 0.3, 0.3])
+        assert result.usage is None
+
+    async def test_generate_embeddings_with_usage(self, test_adapter_with_usage):
+        """Test generate_embeddings with usage information."""
+        texts = ["hello", "world"]
+        result = await test_adapter_with_usage.generate_embeddings(texts)
+        assert len(result.embeddings) == 2
+        assert result.embeddings[0].vector == [0.1, 0.1, 0.1]
+        assert result.embeddings[1].vector == [0.2, 0.2, 0.2]
+        assert result.usage is not None
+        assert result.usage.prompt_tokens == 20
+        assert result.usage.total_tokens == 20
+
+    async def test_generate_embeddings_with_none_input(self, test_adapter):
+        """Test generate_embeddings with None input (should be treated as an empty list)."""
+        result = await test_adapter.generate_embeddings(None)  # type: ignore
+        assert result.embeddings == []
+        assert result.usage is None
+
+    async def test_generate_embeddings_with_whitespace_only_texts(self, test_adapter):
+        """Test generate_embeddings with texts containing only whitespace."""
+        texts = [" ", "\n", "\t"]
+        result = await test_adapter.generate_embeddings(texts)
+        assert len(result.embeddings) == 3
+        # Should still generate embeddings for whitespace-only texts
+        assert result.embeddings[0].vector == [0.1, 0.1, 0.1]
+        assert result.embeddings[1].vector == [0.2, 0.2, 0.2]
+        assert result.embeddings[2].vector == pytest.approx([0.3, 0.3, 0.3])
+
+    async def test_generate_embeddings_with_duplicate_texts(self, test_adapter):
+        """Test generate_embeddings with duplicate texts."""
+        texts = ["hello", "hello", "world"]
+        result = await test_adapter.generate_embeddings(texts)
+        assert len(result.embeddings) == 3
+        # Each text should get its own embedding, even if duplicate
+        assert result.embeddings[0].vector == [0.1, 0.1, 0.1]
+        assert result.embeddings[1].vector == [0.2, 0.2, 0.2]
+        assert result.embeddings[2].vector == pytest.approx([0.3, 0.3, 0.3])
+
+
+class TestBaseEmbeddingAdapterEdgeCases:
+    """Test edge cases for BaseEmbeddingAdapter."""
+
+    async def test_generate_embeddings_with_very_long_text(self, test_adapter):
+        """Test generate_embeddings with very long text."""
+        long_text = "a" * 10000
+        result = await test_adapter.generate_embeddings([long_text])
+        assert len(result.embeddings) == 1
+        assert result.embeddings[0].vector == [0.1, 0.1, 0.1]
+
+    async def test_generate_embeddings_with_special_characters(self, test_adapter):
+        """Test generate_embeddings with special characters."""
+        texts = ["hello\nworld", "test\twith\ttabs", "unicode: 🚀🌟"]
+        result = await test_adapter.generate_embeddings(texts)
+        assert len(result.embeddings) == 3
+        assert result.embeddings[0].vector == [0.1, 0.1, 0.1]
+        assert result.embeddings[1].vector == [0.2, 0.2, 0.2]
+        assert result.embeddings[2].vector == pytest.approx([0.3, 0.3, 0.3])
+
+    async def test_generate_embeddings_with_empty_strings(self, test_adapter):
+        """Test generate_embeddings with empty strings."""
+        texts = ["", "", ""]
+        result = await test_adapter.generate_embeddings(texts)
+        assert len(result.embeddings) == 3
+        assert result.embeddings[0].vector == [0.1, 0.1, 0.1]
+        assert result.embeddings[1].vector == [0.2, 0.2, 0.2]
+        assert result.embeddings[2].vector == pytest.approx([0.3, 0.3, 0.3])
+
+    def test_embedding_config_properties(self, mock_embedding_config):
+        """Test that embedding config properties are accessible."""
+        mock_embedding_config.properties = {"dimensions": 1536, "normalize": True}
+        adapter = MockEmbeddingAdapter(mock_embedding_config)
+        assert adapter.embedding_config.properties["dimensions"] == 1536
+        assert adapter.embedding_config.properties["normalize"] is True
+
+
+class TestBaseEmbeddingAdapterIntegration:
+    """Integration tests for BaseEmbeddingAdapter."""
+
+    async def test_generate_embeddings_method_calls_abstract_method(
+        self, mock_embedding_config
+    ):
+        """Test that generate_embeddings properly calls the abstract _generate_embeddings method."""
+
+        # Create a mock adapter that tracks if _generate_embeddings was called
+        class MockAdapter(BaseEmbeddingAdapter):
+            def __init__(self, config):
+                super().__init__(config)
+                self._generate_embeddings_called = False
+                self._generate_embeddings_args = None
+
+            async def _generate_embeddings(
+                self, text_inputs: list[str]
+            ) -> EmbeddingResult:
+                self._generate_embeddings_called = True
+                self._generate_embeddings_args = text_inputs
+                return EmbeddingResult(embeddings=[])
+
+        adapter = MockAdapter(mock_embedding_config)
+        texts = ["hello", "world"]
+
+        result = await adapter.generate_embeddings(texts)
+
+        assert adapter._generate_embeddings_called
+        assert adapter._generate_embeddings_args == texts
+        assert result.embeddings == []
+        assert result.usage is None
+
+    async def test_generate_embeddings_empty_list_does_not_call_abstract_method(
+        self, mock_embedding_config
+    ):
+        """Test that generate_embeddings with an empty list does not call _generate_embeddings."""
+
+        class MockAdapter(BaseEmbeddingAdapter):
+            def __init__(self, config):
+                super().__init__(config)
+                self._generate_embeddings_called = False
+
+            async def _generate_embeddings(
+                self, text_inputs: list[str]
+            ) -> EmbeddingResult:
+                self._generate_embeddings_called = True
+                return EmbeddingResult(embeddings=[])
+
+        adapter = MockAdapter(mock_embedding_config)
+
+        result = await adapter.generate_embeddings([])
+
+        assert not adapter._generate_embeddings_called
+        assert result.embeddings == []
+        assert result.usage is None
kiln_ai/adapters/embedding/test_embedding_registry.py
@@ -0,0 +1,166 @@
+from unittest.mock import patch
+
+import pytest
+
+from kiln_ai.adapters.embedding.embedding_registry import embedding_adapter_from_type
+from kiln_ai.adapters.embedding.litellm_embedding_adapter import LitellmEmbeddingAdapter
+from kiln_ai.adapters.ml_model_list import ModelProviderName
+from kiln_ai.adapters.provider_tools import LiteLlmCoreConfig
+from kiln_ai.datamodel.embedding import EmbeddingConfig
+
+
+@pytest.fixture
+def mock_provider_configs():
+    with patch("kiln_ai.utils.config.Config.shared") as mock_config:
+        mock_config.return_value.open_ai_api_key = "test-openai-key"
+        mock_config.return_value.gemini_api_key = "test-gemini-key"
+        yield mock_config
+
+
+def test_embedding_adapter_from_type(mock_provider_configs):
+    """Test basic embedding adapter creation with valid config."""
+    embedding_config = EmbeddingConfig(
+        name="test-embedding",
+        model_provider_name=ModelProviderName.gemini_api,
+        model_name="text-embedding-003",
+        properties={"dimensions": 768},
+    )
+
+    adapter = embedding_adapter_from_type(embedding_config)
+
+    assert isinstance(adapter, LitellmEmbeddingAdapter)
+    assert adapter.embedding_config.model_name == "text-embedding-003"
+    assert adapter.embedding_config.model_provider_name == ModelProviderName.gemini_api
+
+
+@patch(
+    "kiln_ai.adapters.embedding.embedding_registry.lite_llm_core_config_for_provider"
+)
+def test_embedding_adapter_from_type_uses_litellm_core_config(
+    mock_get_litellm_core_config,
+):
+    """Test that embedding adapter receives auth details from provider_tools."""
+    mock_litellm_core_config = LiteLlmCoreConfig(
+        base_url="https://test.com",
+        additional_body_options={"api_key": "test-key"},
+        default_headers={},
+    )
+    mock_get_litellm_core_config.return_value = mock_litellm_core_config
+
+    embedding_config = EmbeddingConfig(
+        name="test-embedding",
+        model_provider_name=ModelProviderName.openai,
+        model_name="text-embedding-3-small",
+        properties={"dimensions": 1536},
+    )
+
+    adapter = embedding_adapter_from_type(embedding_config)
+
+    assert isinstance(adapter, LitellmEmbeddingAdapter)
+    assert adapter.litellm_core_config == mock_litellm_core_config
+    mock_get_litellm_core_config.assert_called_once()
+
+
+def test_embedding_adapter_from_type_invalid_provider():
+    """Test that invalid model provider names raise a clear error."""
+    # Create a valid config first, then test the enum conversion logic
+    embedding_config = EmbeddingConfig(
+        name="test-embedding",
+        model_provider_name=ModelProviderName.openai,
+        model_name="some-model",
+        properties={"dimensions": 768},
+    )
+
+    # Mock the ModelProviderName constructor to simulate an invalid provider
+    with patch(
+        "kiln_ai.adapters.embedding.embedding_registry.ModelProviderName"
+    ) as mock_enum:
+        mock_enum.side_effect = ValueError("Invalid provider")
+
+        with pytest.raises(
+            ValueError,
+            match="Unsupported model provider name: openai",
+        ):
+            embedding_adapter_from_type(embedding_config)
+
+
+def test_embedding_adapter_from_type_no_config_found(mock_provider_configs):
+    """Test that missing provider configuration raises an error."""
+    with patch(
+        "kiln_ai.adapters.embedding.embedding_registry.lite_llm_core_config_for_provider"
+    ) as mock_lite_llm_core_config_for_provider:
+        mock_lite_llm_core_config_for_provider.return_value = None
+
+        embedding_config = EmbeddingConfig(
+            name="test-embedding",
+            model_provider_name=ModelProviderName.openai,
+            model_name="text-embedding-3-small",
+            properties={"dimensions": 1536},
+        )
+
+        with pytest.raises(
+            ValueError, match="No configuration found for core provider:"
+        ):
+            embedding_adapter_from_type(embedding_config)
+
+
+@pytest.mark.parametrize(
+    "provider_name",
+    [
+        ModelProviderName.openai,
+        ModelProviderName.gemini_api,
+    ],
+)
+def test_embedding_adapter_from_type_different_providers(
+    provider_name, mock_provider_configs
+):
+    """Test that different providers work correctly."""
+    embedding_config = EmbeddingConfig(
+        name="test-embedding",
+        model_provider_name=provider_name,
+        model_name="test-model",
+        properties={"dimensions": 768},
+    )
+
+    adapter = embedding_adapter_from_type(embedding_config)
+
+    assert isinstance(adapter, LitellmEmbeddingAdapter)
+    assert adapter.embedding_config.model_provider_name == provider_name
+
+
+def test_embedding_adapter_from_type_with_description(mock_provider_configs):
+    """Test embedding adapter creation with description."""
+    embedding_config = EmbeddingConfig(
+        name="test-embedding",
+        description="Test embedding configuration",
+        model_provider_name=ModelProviderName.openai,
+        model_name="text-embedding-3-small",
+        properties={"dimensions": 1536},
+    )
+
+    adapter = embedding_adapter_from_type(embedding_config)
+
+    assert isinstance(adapter, LitellmEmbeddingAdapter)
+    assert adapter.embedding_config.description == "Test embedding configuration"
+
+
+def test_embedding_adapter_from_type_with_additional_properties(
+    mock_provider_configs,
+):
+    """Test embedding adapter creation with additional properties."""
+    embedding_config = EmbeddingConfig(
+        name="test-embedding",
+        model_provider_name=ModelProviderName.openai,
+        model_name="text-embedding-3-small",
+        properties={
+            "dimensions": 1536,
+            "batch_size": 100,
+            "max_retries": 3,
+        },
+    )
+
+    adapter = embedding_adapter_from_type(embedding_config)
+
+    assert isinstance(adapter, LitellmEmbeddingAdapter)
+    assert adapter.embedding_config.properties["batch_size"] == 100
+    assert adapter.embedding_config.properties["max_retries"] == 3