kiln-ai 0.20.1__py3-none-any.whl → 0.22.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of kiln-ai might be problematic. Click here for more details.

Files changed (133) hide show
  1. kiln_ai/adapters/__init__.py +6 -0
  2. kiln_ai/adapters/adapter_registry.py +43 -226
  3. kiln_ai/adapters/chunkers/__init__.py +13 -0
  4. kiln_ai/adapters/chunkers/base_chunker.py +42 -0
  5. kiln_ai/adapters/chunkers/chunker_registry.py +16 -0
  6. kiln_ai/adapters/chunkers/fixed_window_chunker.py +39 -0
  7. kiln_ai/adapters/chunkers/helpers.py +23 -0
  8. kiln_ai/adapters/chunkers/test_base_chunker.py +63 -0
  9. kiln_ai/adapters/chunkers/test_chunker_registry.py +28 -0
  10. kiln_ai/adapters/chunkers/test_fixed_window_chunker.py +346 -0
  11. kiln_ai/adapters/chunkers/test_helpers.py +75 -0
  12. kiln_ai/adapters/data_gen/test_data_gen_task.py +9 -3
  13. kiln_ai/adapters/embedding/__init__.py +0 -0
  14. kiln_ai/adapters/embedding/base_embedding_adapter.py +44 -0
  15. kiln_ai/adapters/embedding/embedding_registry.py +32 -0
  16. kiln_ai/adapters/embedding/litellm_embedding_adapter.py +199 -0
  17. kiln_ai/adapters/embedding/test_base_embedding_adapter.py +283 -0
  18. kiln_ai/adapters/embedding/test_embedding_registry.py +166 -0
  19. kiln_ai/adapters/embedding/test_litellm_embedding_adapter.py +1149 -0
  20. kiln_ai/adapters/eval/eval_runner.py +6 -2
  21. kiln_ai/adapters/eval/test_base_eval.py +1 -3
  22. kiln_ai/adapters/eval/test_g_eval.py +1 -1
  23. kiln_ai/adapters/extractors/__init__.py +18 -0
  24. kiln_ai/adapters/extractors/base_extractor.py +72 -0
  25. kiln_ai/adapters/extractors/encoding.py +20 -0
  26. kiln_ai/adapters/extractors/extractor_registry.py +44 -0
  27. kiln_ai/adapters/extractors/extractor_runner.py +112 -0
  28. kiln_ai/adapters/extractors/litellm_extractor.py +406 -0
  29. kiln_ai/adapters/extractors/test_base_extractor.py +244 -0
  30. kiln_ai/adapters/extractors/test_encoding.py +54 -0
  31. kiln_ai/adapters/extractors/test_extractor_registry.py +181 -0
  32. kiln_ai/adapters/extractors/test_extractor_runner.py +181 -0
  33. kiln_ai/adapters/extractors/test_litellm_extractor.py +1290 -0
  34. kiln_ai/adapters/fine_tune/test_dataset_formatter.py +2 -2
  35. kiln_ai/adapters/fine_tune/test_fireworks_tinetune.py +2 -6
  36. kiln_ai/adapters/fine_tune/test_together_finetune.py +2 -6
  37. kiln_ai/adapters/ml_embedding_model_list.py +494 -0
  38. kiln_ai/adapters/ml_model_list.py +876 -18
  39. kiln_ai/adapters/model_adapters/litellm_adapter.py +40 -75
  40. kiln_ai/adapters/model_adapters/test_litellm_adapter.py +79 -1
  41. kiln_ai/adapters/model_adapters/test_litellm_adapter_tools.py +119 -5
  42. kiln_ai/adapters/model_adapters/test_saving_adapter_results.py +9 -3
  43. kiln_ai/adapters/model_adapters/test_structured_output.py +9 -10
  44. kiln_ai/adapters/ollama_tools.py +69 -12
  45. kiln_ai/adapters/provider_tools.py +190 -46
  46. kiln_ai/adapters/rag/deduplication.py +49 -0
  47. kiln_ai/adapters/rag/progress.py +252 -0
  48. kiln_ai/adapters/rag/rag_runners.py +844 -0
  49. kiln_ai/adapters/rag/test_deduplication.py +195 -0
  50. kiln_ai/adapters/rag/test_progress.py +785 -0
  51. kiln_ai/adapters/rag/test_rag_runners.py +2376 -0
  52. kiln_ai/adapters/remote_config.py +80 -8
  53. kiln_ai/adapters/test_adapter_registry.py +579 -86
  54. kiln_ai/adapters/test_ml_embedding_model_list.py +239 -0
  55. kiln_ai/adapters/test_ml_model_list.py +202 -0
  56. kiln_ai/adapters/test_ollama_tools.py +340 -1
  57. kiln_ai/adapters/test_prompt_builders.py +1 -1
  58. kiln_ai/adapters/test_provider_tools.py +199 -8
  59. kiln_ai/adapters/test_remote_config.py +551 -56
  60. kiln_ai/adapters/vector_store/__init__.py +1 -0
  61. kiln_ai/adapters/vector_store/base_vector_store_adapter.py +83 -0
  62. kiln_ai/adapters/vector_store/lancedb_adapter.py +389 -0
  63. kiln_ai/adapters/vector_store/test_base_vector_store.py +160 -0
  64. kiln_ai/adapters/vector_store/test_lancedb_adapter.py +1841 -0
  65. kiln_ai/adapters/vector_store/test_vector_store_registry.py +199 -0
  66. kiln_ai/adapters/vector_store/vector_store_registry.py +33 -0
  67. kiln_ai/datamodel/__init__.py +16 -13
  68. kiln_ai/datamodel/basemodel.py +201 -4
  69. kiln_ai/datamodel/chunk.py +158 -0
  70. kiln_ai/datamodel/datamodel_enums.py +27 -0
  71. kiln_ai/datamodel/embedding.py +64 -0
  72. kiln_ai/datamodel/external_tool_server.py +206 -54
  73. kiln_ai/datamodel/extraction.py +317 -0
  74. kiln_ai/datamodel/project.py +33 -1
  75. kiln_ai/datamodel/rag.py +79 -0
  76. kiln_ai/datamodel/task.py +5 -0
  77. kiln_ai/datamodel/task_output.py +41 -11
  78. kiln_ai/datamodel/test_attachment.py +649 -0
  79. kiln_ai/datamodel/test_basemodel.py +270 -14
  80. kiln_ai/datamodel/test_chunk_models.py +317 -0
  81. kiln_ai/datamodel/test_dataset_split.py +1 -1
  82. kiln_ai/datamodel/test_datasource.py +50 -0
  83. kiln_ai/datamodel/test_embedding_models.py +448 -0
  84. kiln_ai/datamodel/test_eval_model.py +6 -6
  85. kiln_ai/datamodel/test_external_tool_server.py +534 -152
  86. kiln_ai/datamodel/test_extraction_chunk.py +206 -0
  87. kiln_ai/datamodel/test_extraction_model.py +501 -0
  88. kiln_ai/datamodel/test_rag.py +641 -0
  89. kiln_ai/datamodel/test_task.py +35 -1
  90. kiln_ai/datamodel/test_tool_id.py +187 -1
  91. kiln_ai/datamodel/test_vector_store.py +320 -0
  92. kiln_ai/datamodel/tool_id.py +58 -0
  93. kiln_ai/datamodel/vector_store.py +141 -0
  94. kiln_ai/tools/base_tool.py +12 -3
  95. kiln_ai/tools/built_in_tools/math_tools.py +12 -4
  96. kiln_ai/tools/kiln_task_tool.py +158 -0
  97. kiln_ai/tools/mcp_server_tool.py +2 -2
  98. kiln_ai/tools/mcp_session_manager.py +51 -22
  99. kiln_ai/tools/rag_tools.py +164 -0
  100. kiln_ai/tools/test_kiln_task_tool.py +527 -0
  101. kiln_ai/tools/test_mcp_server_tool.py +4 -15
  102. kiln_ai/tools/test_mcp_session_manager.py +187 -227
  103. kiln_ai/tools/test_rag_tools.py +929 -0
  104. kiln_ai/tools/test_tool_registry.py +290 -7
  105. kiln_ai/tools/tool_registry.py +69 -16
  106. kiln_ai/utils/__init__.py +3 -0
  107. kiln_ai/utils/async_job_runner.py +62 -17
  108. kiln_ai/utils/config.py +2 -2
  109. kiln_ai/utils/env.py +15 -0
  110. kiln_ai/utils/filesystem.py +14 -0
  111. kiln_ai/utils/filesystem_cache.py +60 -0
  112. kiln_ai/utils/litellm.py +94 -0
  113. kiln_ai/utils/lock.py +100 -0
  114. kiln_ai/utils/mime_type.py +38 -0
  115. kiln_ai/utils/open_ai_types.py +19 -2
  116. kiln_ai/utils/pdf_utils.py +59 -0
  117. kiln_ai/utils/test_async_job_runner.py +151 -35
  118. kiln_ai/utils/test_env.py +142 -0
  119. kiln_ai/utils/test_filesystem_cache.py +316 -0
  120. kiln_ai/utils/test_litellm.py +206 -0
  121. kiln_ai/utils/test_lock.py +185 -0
  122. kiln_ai/utils/test_mime_type.py +66 -0
  123. kiln_ai/utils/test_open_ai_types.py +88 -12
  124. kiln_ai/utils/test_pdf_utils.py +86 -0
  125. kiln_ai/utils/test_uuid.py +111 -0
  126. kiln_ai/utils/test_validation.py +524 -0
  127. kiln_ai/utils/uuid.py +9 -0
  128. kiln_ai/utils/validation.py +90 -0
  129. {kiln_ai-0.20.1.dist-info → kiln_ai-0.22.0.dist-info}/METADATA +9 -1
  130. kiln_ai-0.22.0.dist-info/RECORD +213 -0
  131. kiln_ai-0.20.1.dist-info/RECORD +0 -138
  132. {kiln_ai-0.20.1.dist-info → kiln_ai-0.22.0.dist-info}/WHEEL +0 -0
  133. {kiln_ai-0.20.1.dist-info → kiln_ai-0.22.0.dist-info}/licenses/LICENSE.txt +0 -0
@@ -0,0 +1,199 @@
1
+ from functools import cached_property
2
+ from typing import Any, Dict, List, Tuple
3
+
4
+ import litellm
5
+ from litellm.types.utils import EmbeddingResponse
6
+ from pydantic import BaseModel, Field
7
+
8
+ from kiln_ai.adapters.embedding.base_embedding_adapter import (
9
+ BaseEmbeddingAdapter,
10
+ Embedding,
11
+ EmbeddingResult,
12
+ )
13
+ from kiln_ai.adapters.ml_embedding_model_list import (
14
+ KilnEmbeddingModelProvider,
15
+ built_in_embedding_models_from_provider,
16
+ )
17
+ from kiln_ai.adapters.provider_tools import LiteLlmCoreConfig
18
+ from kiln_ai.datamodel.embedding import EmbeddingConfig
19
+ from kiln_ai.utils.litellm import get_litellm_provider_info
20
+
21
# Maximum number of input texts sent in a single embedding request batch.
# litellm enforces a limit, documented here:
# https://docs.litellm.ai/docs/embedding/supported_embedding
# but some providers impose lower limits that LiteLLM does not know about
# for example, Gemini currently has a limit of 100 inputs per request
MAX_BATCH_SIZE = 100
26
+
27
+
28
class EmbeddingOptions(BaseModel):
    """Provider-agnostic options for an embedding request.

    Fields left as None are omitted from the request (callers serialize with
    ``model_dump(exclude_none=True)``).
    """

    dimensions: int | None = Field(
        default=None,
        description="The number of dimensions to return for embeddings. Some models support requesting vectors of different dimensions.",
    )
33
+
34
+
35
def validate_map_to_embeddings(
    response: EmbeddingResponse,
    expected_embedding_count: int,
) -> List[Embedding]:
    """Validate a raw LiteLLM embedding response and map it to Embedding objects.

    LiteLLM's EmbeddingResponse carries its data as a list of untyped dicts,
    which is risky to trust blindly (especially across litellm upgrades), so
    every item is shape-checked before use. Embeddings are returned ordered by
    their reported 'index'.

    Raises:
        RuntimeError: if the response type, item count, or any item's shape is
            not what we expect.
    """
    if not isinstance(response, EmbeddingResponse):
        raise RuntimeError(f"Expected EmbeddingResponse, got {type(response)}.")

    items = response.data
    if len(items) != expected_embedding_count:
        raise RuntimeError(
            f"Expected the number of embeddings in the response to be {expected_embedding_count}, got {len(items)}."
        )

    indexed_vectors: List[Tuple[list[float], int]] = []
    for item in items:
        object_type = item.get("object")
        if object_type != "embedding":
            raise RuntimeError(
                f"Embedding response data has an unexpected shape. Property 'object' is not 'embedding'. Got {object_type}."
            )

        vector = item.get("embedding")
        if vector is None:
            raise RuntimeError(
                "Embedding response data has an unexpected shape. Property 'embedding' is None in response data item."
            )
        if not isinstance(vector, list):
            raise RuntimeError(
                f"Embedding response data has an unexpected shape. Property 'embedding' is not a list. Got {type(vector)}."
            )

        index = item.get("index")
        if index is None:
            raise RuntimeError(
                "Embedding response data has an unexpected shape. Property 'index' is None in response data item."
            )
        if not isinstance(index, int):
            raise RuntimeError(
                f"Embedding response data has an unexpected shape. Property 'index' is not an integer. Got {type(index)}."
            )

        indexed_vectors.append((vector, index))

    # The data should already be sorted by index, but the litellm docs are not
    # explicit about that guarantee, so sort defensively before dropping it.
    return [
        Embedding(vector=vector)
        for vector, _ in sorted(indexed_vectors, key=lambda pair: pair[1])
    ]
87
+
88
+
89
class LitellmEmbeddingAdapter(BaseEmbeddingAdapter):
    """Embedding adapter that calls model providers through LiteLLM.

    Inputs are split into batches of at most MAX_BATCH_SIZE texts; each batch
    is sent as a single `litellm.aembedding` call and the per-batch results
    are merged back together in input order.
    """

    def __init__(
        self, embedding_config: EmbeddingConfig, litellm_core_config: LiteLlmCoreConfig
    ):
        super().__init__(embedding_config)

        # Connection/auth details: base URL, default headers, extra body options.
        self.litellm_core_config = litellm_core_config

    async def _generate_embeddings(self, input_texts: List[str]) -> EmbeddingResult:
        """Embed input_texts, batching requests and merging the results.

        Usage is summed across batches. If any batch is missing usage info,
        the combined usage is None — better than a misleading partial total.
        """
        # Split the inputs into provider-safe batches.
        batches: List[List[str]] = [
            input_texts[i : i + MAX_BATCH_SIZE]
            for i in range(0, len(input_texts), MAX_BATCH_SIZE)
        ]

        # generate embeddings for each batch
        results: List[EmbeddingResult] = []
        for batch in batches:
            batch_response = await self._generate_embeddings_for_batch(batch)
            results.append(batch_response)

        # Merge embeddings in batch order, which preserves input order.
        combined_embeddings: List[Embedding] = []
        for result in results:
            combined_embeddings.extend(result.embeddings)

        # we prefer returning None overall usage if any of the results is missing usage
        # better than returning a misleading usage
        combined_usage = None
        all_have_usage = all(result.usage is not None for result in results)
        if all_have_usage:
            combined_usage = litellm.Usage(
                prompt_tokens=0, total_tokens=0, completion_tokens=0
            )
            for result in results:
                if result.usage is not None:
                    combined_usage.prompt_tokens += result.usage.prompt_tokens
                    combined_usage.total_tokens += result.usage.total_tokens
                    combined_usage.completion_tokens += result.usage.completion_tokens

        return EmbeddingResult(
            embeddings=combined_embeddings,
            usage=combined_usage,
        )

    async def _generate_embeddings_for_batch(
        self, input_texts: List[str]
    ) -> EmbeddingResult:
        """Embed a single batch of texts via one LiteLLM call.

        Raises:
            ValueError: if the batch exceeds MAX_BATCH_SIZE.
            RuntimeError: if the response fails shape validation.
        """
        if len(input_texts) > MAX_BATCH_SIZE:
            raise ValueError(
                f"Too many input texts, max batch size is {MAX_BATCH_SIZE}, got {len(input_texts)}"
            )

        completion_kwargs: Dict[str, Any] = {}
        if self.litellm_core_config.additional_body_options:
            completion_kwargs.update(self.litellm_core_config.additional_body_options)

        if self.litellm_core_config.base_url:
            completion_kwargs["api_base"] = self.litellm_core_config.base_url

        if self.litellm_core_config.default_headers:
            completion_kwargs["default_headers"] = (
                self.litellm_core_config.default_headers
            )

        response = await litellm.aembedding(
            model=self.litellm_model_id,
            input=input_texts,
            **self.build_options().model_dump(exclude_none=True),
            **completion_kwargs,
        )

        validated_embeddings = validate_map_to_embeddings(
            response, expected_embedding_count=len(input_texts)
        )

        return EmbeddingResult(
            embeddings=validated_embeddings,
            usage=response.usage,
        )

    def build_options(self) -> EmbeddingOptions:
        """Build provider options from the embedding config's properties.

        Raises:
            ValueError: if 'dimensions' is present but not a positive integer.
        """
        dimensions = self.embedding_config.properties.get("dimensions", None)
        if dimensions is not None:
            # Exclude bool explicitly: isinstance(True, int) is True in Python,
            # so without this check True would slip through as dimensions=1.
            if (
                isinstance(dimensions, bool)
                or not isinstance(dimensions, int)
                or dimensions <= 0
            ):
                raise ValueError("Dimensions must be a positive integer")

        return EmbeddingOptions(
            dimensions=dimensions,
        )

    @cached_property
    def model_provider(self) -> KilnEmbeddingModelProvider:
        """Resolve the built-in provider entry for the configured model.

        Raises:
            ValueError: if the model is not in the built-in embedding model list.
        """
        provider = built_in_embedding_models_from_provider(
            self.embedding_config.model_provider_name, self.embedding_config.model_name
        )
        if provider is None:
            raise ValueError(
                f"Embedding model {self.embedding_config.model_name} not found in the list of built-in models"
            )
        return provider

    @cached_property
    def litellm_model_id(self) -> str:
        """The model identifier to pass to LiteLLM for this provider/model pair.

        Raises:
            ValueError: if the provider is custom but no base URL is configured.
        """
        provider_info = get_litellm_provider_info(self.model_provider)
        if provider_info.is_custom and self.litellm_core_config.base_url is None:
            raise ValueError(
                f"Provider {self.model_provider.name.value} must have an explicit base URL"
            )

        return provider_info.litellm_model_id
@@ -0,0 +1,283 @@
1
+ import pytest
2
+
3
+ from kiln_ai.adapters.embedding.base_embedding_adapter import (
4
+ BaseEmbeddingAdapter,
5
+ Embedding,
6
+ EmbeddingResult,
7
+ )
8
+ from kiln_ai.datamodel.datamodel_enums import ModelProviderName
9
+ from kiln_ai.datamodel.embedding import EmbeddingConfig
10
+
11
+
12
class MockEmbeddingAdapter(BaseEmbeddingAdapter):
    """Minimal concrete BaseEmbeddingAdapter used as a test double."""

    async def _generate_embeddings(self, text_inputs: list[str]) -> EmbeddingResult:
        # Deterministic fake vectors: the i-th input maps to [0.1 * (i + 1)] * 3.
        vectors = [
            Embedding(vector=[0.1 * (i + 1)] * 3) for i, _ in enumerate(text_inputs)
        ]
        return EmbeddingResult(embeddings=vectors)
22
+
23
+
24
class MockEmbeddingAdapterWithUsage(BaseEmbeddingAdapter):
    """Test double whose results also carry token-usage information."""

    async def _generate_embeddings(self, text_inputs: list[str]) -> EmbeddingResult:
        from litellm import Usage

        vectors = [
            Embedding(vector=[0.1 * (i + 1)] * 3) for i, _ in enumerate(text_inputs)
        ]
        # 10 tokens per input, mirrored into both prompt and total counts.
        token_count = len(text_inputs) * 10
        return EmbeddingResult(
            embeddings=vectors,
            usage=Usage(prompt_tokens=token_count, total_tokens=token_count),
        )
38
+
39
+
40
@pytest.fixture
def mock_embedding_config():
    """A minimal, valid EmbeddingConfig shared by the tests below."""
    config = EmbeddingConfig(
        name="test-embedding",
        model_provider_name=ModelProviderName.openai,
        model_name="openai_text_embedding_3_small",
        properties={},
    )
    return config
49
+
50
+
51
@pytest.fixture
def test_adapter(mock_embedding_config):
    """A MockEmbeddingAdapter wired to the shared config."""
    adapter = MockEmbeddingAdapter(mock_embedding_config)
    return adapter
55
+
56
+
57
@pytest.fixture
def test_adapter_with_usage(mock_embedding_config):
    """A usage-reporting mock adapter wired to the shared config."""
    adapter = MockEmbeddingAdapterWithUsage(mock_embedding_config)
    return adapter
61
+
62
+
63
class TestEmbedding:
    """Unit tests for the Embedding model."""

    def test_creation(self):
        """A vector passed at construction is stored unchanged."""
        source = [0.1, 0.2, 0.3]
        assert Embedding(vector=source).vector == source

    def test_empty_vector(self):
        """An empty vector is allowed and preserved."""
        assert Embedding(vector=[]).vector == []

    def test_large_vector(self):
        """A 1536-dimension vector round-trips with every element intact."""
        embedding = Embedding(vector=[0.1] * 1536)
        assert len(embedding.vector) == 1536
        assert embedding.vector == [0.1] * 1536
83
+
84
+
85
class TestEmbeddingResult:
    """Unit tests for the EmbeddingResult model."""

    def test_creation_with_embeddings(self):
        """Embeddings are stored and usage defaults to None."""
        vectors = [
            Embedding(vector=[0.1, 0.2, 0.3]),
            Embedding(vector=[0.4, 0.5, 0.6]),
        ]
        outcome = EmbeddingResult(embeddings=vectors)
        assert outcome.embeddings == vectors
        assert outcome.usage is None

    def test_creation_with_usage(self):
        """Usage information is stored alongside the embeddings."""
        from litellm import Usage

        vectors = [Embedding(vector=[0.1, 0.2, 0.3])]
        token_usage = Usage(prompt_tokens=10, total_tokens=10)
        outcome = EmbeddingResult(embeddings=vectors, usage=token_usage)
        assert outcome.embeddings == vectors
        assert outcome.usage == token_usage

    def test_empty_embeddings(self):
        """An empty embeddings list is valid; usage stays None."""
        outcome = EmbeddingResult(embeddings=[])
        assert outcome.embeddings == []
        assert outcome.usage is None
113
+
114
+
115
class TestBaseEmbeddingAdapter:
    """Tests for the BaseEmbeddingAdapter abstract base class."""

    def test_init(self, mock_embedding_config):
        """The adapter keeps a reference to the config it was built with."""
        adapter = MockEmbeddingAdapter(mock_embedding_config)
        config = adapter.embedding_config
        assert config == mock_embedding_config
        assert config.name == "test-embedding"
        assert config.model_provider_name == ModelProviderName.openai
        assert config.model_name == "openai_text_embedding_3_small"
        assert config.properties == {}

    def test_cannot_instantiate_abstract_class(self, mock_embedding_config):
        """BaseEmbeddingAdapter is abstract and cannot be constructed directly."""
        with pytest.raises(TypeError):
            BaseEmbeddingAdapter(mock_embedding_config)

    async def test_generate_embeddings_empty_list(self, test_adapter):
        """An empty input list yields an empty result with no usage."""
        outcome = await test_adapter.generate_embeddings([])
        assert outcome.embeddings == []
        assert outcome.usage is None

    async def test_generate_embeddings_single_text(self, test_adapter):
        """A single text yields exactly one mock embedding."""
        outcome = await test_adapter.generate_embeddings(["hello world"])
        assert len(outcome.embeddings) == 1
        assert outcome.embeddings[0].vector == [0.1, 0.1, 0.1]
        assert outcome.usage is None

    async def test_generate_embeddings_multiple_texts(self, test_adapter):
        """Each input gets its own position-dependent mock embedding."""
        inputs = ["hello world", "my name is john", "i like to eat apples"]
        outcome = await test_adapter.generate_embeddings(inputs)
        assert len(outcome.embeddings) == 3
        assert outcome.embeddings[0].vector == [0.1, 0.1, 0.1]
        assert outcome.embeddings[1].vector == [0.2, 0.2, 0.2]
        # 0.1 * 3 is not exactly 0.3 in binary floating point, hence approx.
        assert outcome.embeddings[2].vector == pytest.approx([0.3, 0.3, 0.3])
        assert outcome.usage is None

    async def test_generate_embeddings_with_usage(self, test_adapter_with_usage):
        """Adapters may attach usage; the mock counts 10 tokens per input."""
        outcome = await test_adapter_with_usage.generate_embeddings(["hello", "world"])
        assert len(outcome.embeddings) == 2
        assert outcome.embeddings[0].vector == [0.1, 0.1, 0.1]
        assert outcome.embeddings[1].vector == [0.2, 0.2, 0.2]
        assert outcome.usage is not None
        assert outcome.usage.prompt_tokens == 20
        assert outcome.usage.total_tokens == 20

    async def test_generate_embeddings_with_none_input(self, test_adapter):
        """None input behaves like an empty list."""
        outcome = await test_adapter.generate_embeddings(None)  # type: ignore
        assert outcome.embeddings == []
        assert outcome.usage is None

    async def test_generate_embeddings_with_whitespace_only_texts(self, test_adapter):
        """Whitespace-only strings are still embedded, one vector each."""
        outcome = await test_adapter.generate_embeddings([" ", "\n", "\t"])
        assert len(outcome.embeddings) == 3
        assert outcome.embeddings[0].vector == [0.1, 0.1, 0.1]
        assert outcome.embeddings[1].vector == [0.2, 0.2, 0.2]
        assert outcome.embeddings[2].vector == pytest.approx([0.3, 0.3, 0.3])

    async def test_generate_embeddings_with_duplicate_texts(self, test_adapter):
        """Duplicates are not deduplicated: one embedding per input position."""
        outcome = await test_adapter.generate_embeddings(["hello", "hello", "world"])
        assert len(outcome.embeddings) == 3
        assert outcome.embeddings[0].vector == [0.1, 0.1, 0.1]
        assert outcome.embeddings[1].vector == [0.2, 0.2, 0.2]
        assert outcome.embeddings[2].vector == pytest.approx([0.3, 0.3, 0.3])
191
+
192
+
193
class TestBaseEmbeddingAdapterEdgeCases:
    """Edge-case tests for BaseEmbeddingAdapter."""

    async def test_generate_embeddings_with_very_long_text(self, test_adapter):
        """A very long input still yields a single embedding."""
        outcome = await test_adapter.generate_embeddings(["a" * 10000])
        assert len(outcome.embeddings) == 1
        assert outcome.embeddings[0].vector == [0.1, 0.1, 0.1]

    async def test_generate_embeddings_with_special_characters(self, test_adapter):
        """Newlines, tabs, and unicode inputs are embedded like any other text."""
        inputs = ["hello\nworld", "test\twith\ttabs", "unicode: 🚀🌟"]
        outcome = await test_adapter.generate_embeddings(inputs)
        assert len(outcome.embeddings) == 3
        assert outcome.embeddings[0].vector == [0.1, 0.1, 0.1]
        assert outcome.embeddings[1].vector == [0.2, 0.2, 0.2]
        assert outcome.embeddings[2].vector == pytest.approx([0.3, 0.3, 0.3])

    async def test_generate_embeddings_with_empty_strings(self, test_adapter):
        """Empty strings each still produce an embedding."""
        outcome = await test_adapter.generate_embeddings(["", "", ""])
        assert len(outcome.embeddings) == 3
        assert outcome.embeddings[0].vector == [0.1, 0.1, 0.1]
        assert outcome.embeddings[1].vector == [0.2, 0.2, 0.2]
        assert outcome.embeddings[2].vector == pytest.approx([0.3, 0.3, 0.3])

    def test_embedding_config_properties(self, mock_embedding_config):
        """Arbitrary config properties remain accessible through the adapter."""
        mock_embedding_config.properties = {"dimensions": 1536, "normalize": True}
        adapter = MockEmbeddingAdapter(mock_embedding_config)
        assert adapter.embedding_config.properties["dimensions"] == 1536
        assert adapter.embedding_config.properties["normalize"] is True
227
+
228
+
229
class TestBaseEmbeddingAdapterIntegration:
    """Integration tests for BaseEmbeddingAdapter."""

    async def test_generate_embeddings_method_calls_abstract_method(
        self, mock_embedding_config
    ):
        """generate_embeddings delegates to _generate_embeddings with the same texts."""

        class RecordingAdapter(BaseEmbeddingAdapter):
            """Adapter that records whether and how _generate_embeddings ran."""

            def __init__(self, config):
                super().__init__(config)
                self.called = False
                self.received = None

            async def _generate_embeddings(
                self, text_inputs: list[str]
            ) -> EmbeddingResult:
                self.called = True
                self.received = text_inputs
                return EmbeddingResult(embeddings=[])

        adapter = RecordingAdapter(mock_embedding_config)
        inputs = ["hello", "world"]

        outcome = await adapter.generate_embeddings(inputs)

        assert adapter.called
        assert adapter.received == inputs
        assert outcome.embeddings == []
        assert outcome.usage is None

    async def test_generate_embeddings_empty_list_does_not_call_abstract_method(
        self, mock_embedding_config
    ):
        """An empty input list short-circuits without calling _generate_embeddings."""

        class RecordingAdapter(BaseEmbeddingAdapter):
            def __init__(self, config):
                super().__init__(config)
                self.called = False

            async def _generate_embeddings(
                self, text_inputs: list[str]
            ) -> EmbeddingResult:
                self.called = True
                return EmbeddingResult(embeddings=[])

        adapter = RecordingAdapter(mock_embedding_config)

        outcome = await adapter.generate_embeddings([])

        assert not adapter.called
        assert outcome.embeddings == []
        assert outcome.usage is None
@@ -0,0 +1,166 @@
1
+ from unittest.mock import patch
2
+
3
+ import pytest
4
+
5
+ from kiln_ai.adapters.embedding.embedding_registry import embedding_adapter_from_type
6
+ from kiln_ai.adapters.embedding.litellm_embedding_adapter import LitellmEmbeddingAdapter
7
+ from kiln_ai.adapters.ml_model_list import ModelProviderName
8
+ from kiln_ai.adapters.provider_tools import LiteLlmCoreConfig
9
+ from kiln_ai.datamodel.embedding import EmbeddingConfig
10
+
11
+
12
@pytest.fixture
def mock_provider_configs():
    """Patch Config.shared so provider API keys appear to be configured."""
    with patch("kiln_ai.utils.config.Config.shared") as shared_config:
        shared_config.return_value.open_ai_api_key = "test-openai-key"
        shared_config.return_value.gemini_api_key = "test-gemini-key"
        yield shared_config
18
+
19
+
20
def test_embedding_adapter_from_type(mock_provider_configs):
    """A valid config produces a LitellmEmbeddingAdapter carrying that config."""
    config = EmbeddingConfig(
        name="test-embedding",
        model_provider_name=ModelProviderName.gemini_api,
        model_name="text-embedding-003",
        properties={"dimensions": 768},
    )

    adapter = embedding_adapter_from_type(config)

    assert isinstance(adapter, LitellmEmbeddingAdapter)
    assert adapter.embedding_config.model_name == "text-embedding-003"
    assert adapter.embedding_config.model_provider_name == ModelProviderName.gemini_api
34
+
35
+
36
@patch(
    "kiln_ai.adapters.embedding.embedding_registry.lite_llm_core_config_for_provider"
)
def test_embedding_adapter_from_type_uses_litellm_core_config(
    mock_get_litellm_core_config,
):
    """The adapter is constructed with the core config resolved for the provider."""
    core_config = LiteLlmCoreConfig(
        base_url="https://test.com",
        additional_body_options={"api_key": "test-key"},
        default_headers={},
    )
    mock_get_litellm_core_config.return_value = core_config

    config = EmbeddingConfig(
        name="test-embedding",
        model_provider_name=ModelProviderName.openai,
        model_name="text-embedding-3-small",
        properties={"dimensions": 1536},
    )

    adapter = embedding_adapter_from_type(config)

    assert isinstance(adapter, LitellmEmbeddingAdapter)
    assert adapter.litellm_core_config == core_config
    mock_get_litellm_core_config.assert_called_once()
62
+
63
+
64
def test_embedding_adapter_from_type_invalid_provider():
    """An unrecognized provider name surfaces as a clear ValueError."""
    # Build a valid config; the failure is injected at enum-conversion time.
    config = EmbeddingConfig(
        name="test-embedding",
        model_provider_name=ModelProviderName.openai,
        model_name="some-model",
        properties={"dimensions": 768},
    )

    # Simulate the registry failing to convert the provider name to the enum.
    with patch(
        "kiln_ai.adapters.embedding.embedding_registry.ModelProviderName"
    ) as enum_mock:
        enum_mock.side_effect = ValueError("Invalid provider")

        with pytest.raises(
            ValueError,
            match="Unsupported model provider name: openai",
        ):
            embedding_adapter_from_type(config)
85
+
86
+
87
def test_embedding_adapter_from_type_no_config_found(mock_provider_configs):
    """A provider with no resolvable core config raises a ValueError."""
    with patch(
        "kiln_ai.adapters.embedding.embedding_registry.lite_llm_core_config_for_provider"
    ) as core_config_lookup:
        core_config_lookup.return_value = None

        config = EmbeddingConfig(
            name="test-embedding",
            model_provider_name=ModelProviderName.openai,
            model_name="text-embedding-3-small",
            properties={"dimensions": 1536},
        )

        with pytest.raises(
            ValueError, match="No configuration found for core provider:"
        ):
            embedding_adapter_from_type(config)
105
+
106
+
107
@pytest.mark.parametrize(
    "provider_name",
    [
        ModelProviderName.openai,
        ModelProviderName.gemini_api,
    ],
)
def test_embedding_adapter_from_type_different_providers(
    provider_name, mock_provider_configs
):
    """Adapter creation succeeds for each supported provider."""
    config = EmbeddingConfig(
        name="test-embedding",
        model_provider_name=provider_name,
        model_name="test-model",
        properties={"dimensions": 768},
    )

    adapter = embedding_adapter_from_type(config)

    assert isinstance(adapter, LitellmEmbeddingAdapter)
    assert adapter.embedding_config.model_provider_name == provider_name
129
+
130
+
131
def test_embedding_adapter_from_type_with_description(mock_provider_configs):
    """An optional description on the config is preserved on the adapter."""
    config = EmbeddingConfig(
        name="test-embedding",
        description="Test embedding configuration",
        model_provider_name=ModelProviderName.openai,
        model_name="text-embedding-3-small",
        properties={"dimensions": 1536},
    )

    adapter = embedding_adapter_from_type(config)

    assert isinstance(adapter, LitellmEmbeddingAdapter)
    assert adapter.embedding_config.description == "Test embedding configuration"
145
+
146
+
147
def test_embedding_adapter_from_type_with_additional_properties(
    mock_provider_configs,
):
    """Extra, unrecognized properties pass through to the adapter's config."""
    config = EmbeddingConfig(
        name="test-embedding",
        model_provider_name=ModelProviderName.openai,
        model_name="text-embedding-3-small",
        properties={
            "dimensions": 1536,
            "batch_size": 100,
            "max_retries": 3,
        },
    )

    adapter = embedding_adapter_from_type(config)

    assert isinstance(adapter, LitellmEmbeddingAdapter)
    assert adapter.embedding_config.properties["batch_size"] == 100
    assert adapter.embedding_config.properties["max_retries"] == 3