kiln-ai 0.20.1__py3-none-any.whl → 0.22.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of kiln-ai might be problematic. Click here for more details.
- kiln_ai/adapters/__init__.py +6 -0
- kiln_ai/adapters/adapter_registry.py +43 -226
- kiln_ai/adapters/chunkers/__init__.py +13 -0
- kiln_ai/adapters/chunkers/base_chunker.py +42 -0
- kiln_ai/adapters/chunkers/chunker_registry.py +16 -0
- kiln_ai/adapters/chunkers/fixed_window_chunker.py +39 -0
- kiln_ai/adapters/chunkers/helpers.py +23 -0
- kiln_ai/adapters/chunkers/test_base_chunker.py +63 -0
- kiln_ai/adapters/chunkers/test_chunker_registry.py +28 -0
- kiln_ai/adapters/chunkers/test_fixed_window_chunker.py +346 -0
- kiln_ai/adapters/chunkers/test_helpers.py +75 -0
- kiln_ai/adapters/data_gen/test_data_gen_task.py +9 -3
- kiln_ai/adapters/embedding/__init__.py +0 -0
- kiln_ai/adapters/embedding/base_embedding_adapter.py +44 -0
- kiln_ai/adapters/embedding/embedding_registry.py +32 -0
- kiln_ai/adapters/embedding/litellm_embedding_adapter.py +199 -0
- kiln_ai/adapters/embedding/test_base_embedding_adapter.py +283 -0
- kiln_ai/adapters/embedding/test_embedding_registry.py +166 -0
- kiln_ai/adapters/embedding/test_litellm_embedding_adapter.py +1149 -0
- kiln_ai/adapters/eval/eval_runner.py +6 -2
- kiln_ai/adapters/eval/test_base_eval.py +1 -3
- kiln_ai/adapters/eval/test_g_eval.py +1 -1
- kiln_ai/adapters/extractors/__init__.py +18 -0
- kiln_ai/adapters/extractors/base_extractor.py +72 -0
- kiln_ai/adapters/extractors/encoding.py +20 -0
- kiln_ai/adapters/extractors/extractor_registry.py +44 -0
- kiln_ai/adapters/extractors/extractor_runner.py +112 -0
- kiln_ai/adapters/extractors/litellm_extractor.py +406 -0
- kiln_ai/adapters/extractors/test_base_extractor.py +244 -0
- kiln_ai/adapters/extractors/test_encoding.py +54 -0
- kiln_ai/adapters/extractors/test_extractor_registry.py +181 -0
- kiln_ai/adapters/extractors/test_extractor_runner.py +181 -0
- kiln_ai/adapters/extractors/test_litellm_extractor.py +1290 -0
- kiln_ai/adapters/fine_tune/test_dataset_formatter.py +2 -2
- kiln_ai/adapters/fine_tune/test_fireworks_tinetune.py +2 -6
- kiln_ai/adapters/fine_tune/test_together_finetune.py +2 -6
- kiln_ai/adapters/ml_embedding_model_list.py +494 -0
- kiln_ai/adapters/ml_model_list.py +876 -18
- kiln_ai/adapters/model_adapters/litellm_adapter.py +40 -75
- kiln_ai/adapters/model_adapters/test_litellm_adapter.py +79 -1
- kiln_ai/adapters/model_adapters/test_litellm_adapter_tools.py +119 -5
- kiln_ai/adapters/model_adapters/test_saving_adapter_results.py +9 -3
- kiln_ai/adapters/model_adapters/test_structured_output.py +9 -10
- kiln_ai/adapters/ollama_tools.py +69 -12
- kiln_ai/adapters/provider_tools.py +190 -46
- kiln_ai/adapters/rag/deduplication.py +49 -0
- kiln_ai/adapters/rag/progress.py +252 -0
- kiln_ai/adapters/rag/rag_runners.py +844 -0
- kiln_ai/adapters/rag/test_deduplication.py +195 -0
- kiln_ai/adapters/rag/test_progress.py +785 -0
- kiln_ai/adapters/rag/test_rag_runners.py +2376 -0
- kiln_ai/adapters/remote_config.py +80 -8
- kiln_ai/adapters/test_adapter_registry.py +579 -86
- kiln_ai/adapters/test_ml_embedding_model_list.py +239 -0
- kiln_ai/adapters/test_ml_model_list.py +202 -0
- kiln_ai/adapters/test_ollama_tools.py +340 -1
- kiln_ai/adapters/test_prompt_builders.py +1 -1
- kiln_ai/adapters/test_provider_tools.py +199 -8
- kiln_ai/adapters/test_remote_config.py +551 -56
- kiln_ai/adapters/vector_store/__init__.py +1 -0
- kiln_ai/adapters/vector_store/base_vector_store_adapter.py +83 -0
- kiln_ai/adapters/vector_store/lancedb_adapter.py +389 -0
- kiln_ai/adapters/vector_store/test_base_vector_store.py +160 -0
- kiln_ai/adapters/vector_store/test_lancedb_adapter.py +1841 -0
- kiln_ai/adapters/vector_store/test_vector_store_registry.py +199 -0
- kiln_ai/adapters/vector_store/vector_store_registry.py +33 -0
- kiln_ai/datamodel/__init__.py +16 -13
- kiln_ai/datamodel/basemodel.py +201 -4
- kiln_ai/datamodel/chunk.py +158 -0
- kiln_ai/datamodel/datamodel_enums.py +27 -0
- kiln_ai/datamodel/embedding.py +64 -0
- kiln_ai/datamodel/external_tool_server.py +206 -54
- kiln_ai/datamodel/extraction.py +317 -0
- kiln_ai/datamodel/project.py +33 -1
- kiln_ai/datamodel/rag.py +79 -0
- kiln_ai/datamodel/task.py +5 -0
- kiln_ai/datamodel/task_output.py +41 -11
- kiln_ai/datamodel/test_attachment.py +649 -0
- kiln_ai/datamodel/test_basemodel.py +270 -14
- kiln_ai/datamodel/test_chunk_models.py +317 -0
- kiln_ai/datamodel/test_dataset_split.py +1 -1
- kiln_ai/datamodel/test_datasource.py +50 -0
- kiln_ai/datamodel/test_embedding_models.py +448 -0
- kiln_ai/datamodel/test_eval_model.py +6 -6
- kiln_ai/datamodel/test_external_tool_server.py +534 -152
- kiln_ai/datamodel/test_extraction_chunk.py +206 -0
- kiln_ai/datamodel/test_extraction_model.py +501 -0
- kiln_ai/datamodel/test_rag.py +641 -0
- kiln_ai/datamodel/test_task.py +35 -1
- kiln_ai/datamodel/test_tool_id.py +187 -1
- kiln_ai/datamodel/test_vector_store.py +320 -0
- kiln_ai/datamodel/tool_id.py +58 -0
- kiln_ai/datamodel/vector_store.py +141 -0
- kiln_ai/tools/base_tool.py +12 -3
- kiln_ai/tools/built_in_tools/math_tools.py +12 -4
- kiln_ai/tools/kiln_task_tool.py +158 -0
- kiln_ai/tools/mcp_server_tool.py +2 -2
- kiln_ai/tools/mcp_session_manager.py +51 -22
- kiln_ai/tools/rag_tools.py +164 -0
- kiln_ai/tools/test_kiln_task_tool.py +527 -0
- kiln_ai/tools/test_mcp_server_tool.py +4 -15
- kiln_ai/tools/test_mcp_session_manager.py +187 -227
- kiln_ai/tools/test_rag_tools.py +929 -0
- kiln_ai/tools/test_tool_registry.py +290 -7
- kiln_ai/tools/tool_registry.py +69 -16
- kiln_ai/utils/__init__.py +3 -0
- kiln_ai/utils/async_job_runner.py +62 -17
- kiln_ai/utils/config.py +2 -2
- kiln_ai/utils/env.py +15 -0
- kiln_ai/utils/filesystem.py +14 -0
- kiln_ai/utils/filesystem_cache.py +60 -0
- kiln_ai/utils/litellm.py +94 -0
- kiln_ai/utils/lock.py +100 -0
- kiln_ai/utils/mime_type.py +38 -0
- kiln_ai/utils/open_ai_types.py +19 -2
- kiln_ai/utils/pdf_utils.py +59 -0
- kiln_ai/utils/test_async_job_runner.py +151 -35
- kiln_ai/utils/test_env.py +142 -0
- kiln_ai/utils/test_filesystem_cache.py +316 -0
- kiln_ai/utils/test_litellm.py +206 -0
- kiln_ai/utils/test_lock.py +185 -0
- kiln_ai/utils/test_mime_type.py +66 -0
- kiln_ai/utils/test_open_ai_types.py +88 -12
- kiln_ai/utils/test_pdf_utils.py +86 -0
- kiln_ai/utils/test_uuid.py +111 -0
- kiln_ai/utils/test_validation.py +524 -0
- kiln_ai/utils/uuid.py +9 -0
- kiln_ai/utils/validation.py +90 -0
- {kiln_ai-0.20.1.dist-info → kiln_ai-0.22.0.dist-info}/METADATA +9 -1
- kiln_ai-0.22.0.dist-info/RECORD +213 -0
- kiln_ai-0.20.1.dist-info/RECORD +0 -138
- {kiln_ai-0.20.1.dist-info → kiln_ai-0.22.0.dist-info}/WHEEL +0 -0
- {kiln_ai-0.20.1.dist-info → kiln_ai-0.22.0.dist-info}/licenses/LICENSE.txt +0 -0
|
@@ -1,4 +1,5 @@
|
|
|
1
|
-
from
|
|
1
|
+
from os import getenv
|
|
2
|
+
from unittest.mock import Mock, patch
|
|
2
3
|
|
|
3
4
|
import pytest
|
|
4
5
|
|
|
@@ -6,17 +7,39 @@ from kiln_ai import datamodel
|
|
|
6
7
|
from kiln_ai.adapters.adapter_registry import adapter_for_task
|
|
7
8
|
from kiln_ai.adapters.ml_model_list import ModelProviderName
|
|
8
9
|
from kiln_ai.adapters.model_adapters.base_adapter import AdapterConfig
|
|
9
|
-
from kiln_ai.adapters.model_adapters.litellm_adapter import
|
|
10
|
-
|
|
10
|
+
from kiln_ai.adapters.model_adapters.litellm_adapter import (
|
|
11
|
+
LiteLlmAdapter,
|
|
12
|
+
LiteLlmConfig,
|
|
13
|
+
)
|
|
14
|
+
from kiln_ai.adapters.provider_tools import (
|
|
15
|
+
Config,
|
|
16
|
+
LiteLlmCoreConfig,
|
|
17
|
+
lite_llm_core_config_for_provider,
|
|
18
|
+
)
|
|
11
19
|
from kiln_ai.datamodel.datamodel_enums import StructuredOutputMode
|
|
12
20
|
from kiln_ai.datamodel.task import RunConfigProperties
|
|
13
21
|
|
|
14
22
|
|
|
15
23
|
@pytest.fixture
|
|
16
24
|
def mock_config():
|
|
17
|
-
with patch("kiln_ai.adapters.
|
|
25
|
+
with patch("kiln_ai.adapters.provider_tools.Config") as mock:
|
|
18
26
|
mock.shared.return_value.open_ai_api_key = "test-openai-key"
|
|
19
27
|
mock.shared.return_value.open_router_api_key = "test-openrouter-key"
|
|
28
|
+
mock.shared.return_value.groq_api_key = "test-groq-key"
|
|
29
|
+
mock.shared.return_value.bedrock_access_key = "test-bedrock-access-key"
|
|
30
|
+
mock.shared.return_value.bedrock_secret_key = "test-bedrock-secret-key"
|
|
31
|
+
mock.shared.return_value.huggingface_api_key = "test-huggingface-key"
|
|
32
|
+
mock.shared.return_value.ollama_base_url = "http://localhost:11434/v1"
|
|
33
|
+
mock.shared.return_value.fireworks_api_key = "test-fireworks-key"
|
|
34
|
+
mock.shared.return_value.anthropic_api_key = "test-anthropic-key"
|
|
35
|
+
mock.shared.return_value.gemini_api_key = "test-gemini-key"
|
|
36
|
+
mock.shared.return_value.vertex_project_id = "test-vertex-project-id"
|
|
37
|
+
mock.shared.return_value.vertex_location = "test-vertex-location"
|
|
38
|
+
mock.shared.return_value.together_api_key = "test-together-key"
|
|
39
|
+
mock.shared.return_value.azure_openai_api_key = "test-azure-openai-key"
|
|
40
|
+
mock.shared.return_value.azure_openai_endpoint = (
|
|
41
|
+
"https://test-azure-openai-endpoint.com/v1"
|
|
42
|
+
)
|
|
20
43
|
mock.shared.return_value.siliconflow_cn_api_key = "test-siliconflow-key"
|
|
21
44
|
mock.shared.return_value.docker_model_runner_base_url = (
|
|
22
45
|
"http://localhost:12434/engines/llama.cpp"
|
|
@@ -35,57 +58,97 @@ def basic_task():
|
|
|
35
58
|
)
|
|
36
59
|
|
|
37
60
|
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
61
|
+
def test_openai_adapter_creation(mock_config, basic_task):
|
|
62
|
+
with patch(
|
|
63
|
+
"kiln_ai.adapters.adapter_registry.lite_llm_core_config_for_provider"
|
|
64
|
+
) as mock_lite_llm_core_config_for_provider:
|
|
65
|
+
mock_lite_llm_core_config = LiteLlmCoreConfig(
|
|
66
|
+
additional_body_options={"api_key": "test-openai-key"},
|
|
67
|
+
)
|
|
68
|
+
mock_lite_llm_core_config_for_provider.return_value = mock_lite_llm_core_config
|
|
45
69
|
|
|
70
|
+
adapter = adapter_for_task(
|
|
71
|
+
kiln_task=basic_task,
|
|
72
|
+
run_config_properties=RunConfigProperties(
|
|
73
|
+
model_name="gpt-4",
|
|
74
|
+
model_provider_name=ModelProviderName.openai,
|
|
75
|
+
prompt_id="simple_prompt_builder",
|
|
76
|
+
structured_output_mode="json_schema",
|
|
77
|
+
),
|
|
78
|
+
)
|
|
46
79
|
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
model_name="gpt-4",
|
|
52
|
-
model_provider_name=ModelProviderName.openai,
|
|
53
|
-
prompt_id="simple_prompt_builder",
|
|
54
|
-
structured_output_mode="json_schema",
|
|
55
|
-
),
|
|
56
|
-
)
|
|
80
|
+
# Verify the connection details were accessed (not openai_compatible bypass)
|
|
81
|
+
mock_lite_llm_core_config_for_provider.assert_called_once_with(
|
|
82
|
+
ModelProviderName.openai, None
|
|
83
|
+
)
|
|
57
84
|
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
85
|
+
# Verify adapter configuration
|
|
86
|
+
assert isinstance(adapter, LiteLlmAdapter)
|
|
87
|
+
assert adapter.config.run_config_properties.model_name == "gpt-4"
|
|
88
|
+
assert adapter.config.base_url == mock_lite_llm_core_config.base_url
|
|
89
|
+
assert (
|
|
90
|
+
adapter.config.default_headers == mock_lite_llm_core_config.default_headers
|
|
91
|
+
)
|
|
92
|
+
assert (
|
|
93
|
+
adapter.config.additional_body_options
|
|
94
|
+
== mock_lite_llm_core_config.additional_body_options
|
|
95
|
+
)
|
|
96
|
+
assert (
|
|
97
|
+
adapter.config.run_config_properties.model_provider_name
|
|
98
|
+
== ModelProviderName.openai
|
|
99
|
+
)
|
|
100
|
+
assert adapter.config.base_url is None
|
|
101
|
+
assert adapter.config.default_headers is None
|
|
67
102
|
|
|
68
103
|
|
|
69
104
|
def test_openrouter_adapter_creation(mock_config, basic_task):
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
105
|
+
with patch(
|
|
106
|
+
"kiln_ai.adapters.adapter_registry.lite_llm_core_config_for_provider"
|
|
107
|
+
) as mock_lite_llm_core_config_for_provider:
|
|
108
|
+
mock_lite_llm_core_config = LiteLlmCoreConfig(
|
|
109
|
+
additional_body_options={"api_key": "test-openrouter-key"},
|
|
110
|
+
base_url="https://openrouter.ai/api/v1",
|
|
111
|
+
default_headers={
|
|
112
|
+
"HTTP-Referer": "https://kiln.tech/openrouter",
|
|
113
|
+
"X-Title": "KilnAI",
|
|
114
|
+
},
|
|
115
|
+
)
|
|
116
|
+
mock_lite_llm_core_config_for_provider.return_value = mock_lite_llm_core_config
|
|
79
117
|
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
118
|
+
adapter = adapter_for_task(
|
|
119
|
+
kiln_task=basic_task,
|
|
120
|
+
run_config_properties=RunConfigProperties(
|
|
121
|
+
model_name="anthropic/claude-3-opus",
|
|
122
|
+
model_provider_name=ModelProviderName.openrouter,
|
|
123
|
+
prompt_id="simple_prompt_builder",
|
|
124
|
+
structured_output_mode="json_schema",
|
|
125
|
+
),
|
|
126
|
+
)
|
|
127
|
+
|
|
128
|
+
# Verify the connection details were accessed (not openai_compatible bypass)
|
|
129
|
+
mock_lite_llm_core_config_for_provider.assert_called_once_with(
|
|
130
|
+
ModelProviderName.openrouter, None
|
|
131
|
+
)
|
|
132
|
+
|
|
133
|
+
# Verify adapter configuration including complex auth (headers + base_url)
|
|
134
|
+
assert isinstance(adapter, LiteLlmAdapter)
|
|
135
|
+
assert (
|
|
136
|
+
adapter.config.run_config_properties.model_name == "anthropic/claude-3-opus"
|
|
137
|
+
)
|
|
138
|
+
assert adapter.config.additional_body_options == {
|
|
139
|
+
"api_key": "test-openrouter-key"
|
|
140
|
+
}
|
|
141
|
+
assert adapter.config.base_url == "https://openrouter.ai/api/v1"
|
|
142
|
+
assert adapter.config.default_headers == {
|
|
143
|
+
"HTTP-Referer": "https://kiln.tech/openrouter",
|
|
144
|
+
"X-Title": "KilnAI",
|
|
145
|
+
}
|
|
146
|
+
assert (
|
|
147
|
+
adapter.config.run_config_properties.model_provider_name
|
|
148
|
+
== ModelProviderName.openrouter
|
|
149
|
+
)
|
|
87
150
|
assert adapter.config.default_headers == {
|
|
88
|
-
"HTTP-Referer": "https://
|
|
151
|
+
"HTTP-Referer": "https://kiln.tech/openrouter",
|
|
89
152
|
"X-Title": "KilnAI",
|
|
90
153
|
}
|
|
91
154
|
|
|
@@ -124,6 +187,13 @@ def test_siliconflow_adapter_creation(mock_config, basic_task):
|
|
|
124
187
|
ModelProviderName.amazon_bedrock,
|
|
125
188
|
ModelProviderName.ollama,
|
|
126
189
|
ModelProviderName.fireworks_ai,
|
|
190
|
+
ModelProviderName.anthropic,
|
|
191
|
+
ModelProviderName.gemini_api,
|
|
192
|
+
ModelProviderName.vertex,
|
|
193
|
+
ModelProviderName.together_ai,
|
|
194
|
+
ModelProviderName.azure_openai,
|
|
195
|
+
ModelProviderName.huggingface,
|
|
196
|
+
ModelProviderName.openrouter,
|
|
127
197
|
],
|
|
128
198
|
)
|
|
129
199
|
def test_openai_compatible_adapter_creation(mock_config, basic_task, provider):
|
|
@@ -188,34 +258,39 @@ def test_invalid_provider(mock_config, basic_task):
|
|
|
188
258
|
)
|
|
189
259
|
|
|
190
260
|
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
|
|
201
|
-
model_provider_name=ModelProviderName.openai_compatible,
|
|
202
|
-
prompt_id="simple_prompt_builder",
|
|
203
|
-
structured_output_mode="json_schema",
|
|
204
|
-
)
|
|
261
|
+
def test_openai_compatible_adapter(basic_task):
|
|
262
|
+
# patch Config.shared().openai_compatible_providers
|
|
263
|
+
with patch("kiln_ai.adapters.provider_tools.Config.shared") as mock_config_shared:
|
|
264
|
+
mock_config_shared.return_value.openai_compatible_providers = [
|
|
265
|
+
{
|
|
266
|
+
"name": "some-provider",
|
|
267
|
+
"base_url": "https://test.com/v1",
|
|
268
|
+
"api_key": "test-key",
|
|
269
|
+
}
|
|
270
|
+
]
|
|
205
271
|
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
|
|
213
|
-
|
|
214
|
-
|
|
272
|
+
adapter = adapter_for_task(
|
|
273
|
+
kiln_task=basic_task,
|
|
274
|
+
run_config_properties=RunConfigProperties(
|
|
275
|
+
model_name="some-provider::test-model",
|
|
276
|
+
model_provider_name=ModelProviderName.openai_compatible,
|
|
277
|
+
prompt_id="simple_prompt_builder",
|
|
278
|
+
structured_output_mode="json_schema",
|
|
279
|
+
),
|
|
280
|
+
)
|
|
215
281
|
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
|
|
282
|
+
assert isinstance(adapter, LiteLlmAdapter)
|
|
283
|
+
assert adapter.config.additional_body_options == {"api_key": "test-key"}
|
|
284
|
+
assert adapter.config.base_url == "https://test.com/v1"
|
|
285
|
+
assert adapter.config.run_config_properties.model_name == "test-model"
|
|
286
|
+
assert (
|
|
287
|
+
adapter.config.run_config_properties.model_provider_name
|
|
288
|
+
== "openai_compatible"
|
|
289
|
+
)
|
|
290
|
+
assert adapter.config.run_config_properties.prompt_id == "simple_prompt_builder"
|
|
291
|
+
assert (
|
|
292
|
+
adapter.config.run_config_properties.structured_output_mode == "json_schema"
|
|
293
|
+
)
|
|
219
294
|
|
|
220
295
|
|
|
221
296
|
def test_custom_openai_compatible_provider(mock_config, basic_task):
|
|
@@ -239,31 +314,449 @@ def test_custom_openai_compatible_provider(mock_config, basic_task):
|
|
|
239
314
|
)
|
|
240
315
|
|
|
241
316
|
|
|
242
|
-
|
|
317
|
+
@pytest.fixture
|
|
318
|
+
def mock_lite_llm_core_config_for_provider():
|
|
319
|
+
"""Mock lite_llm_core_config_for_provider to return predictable auth details."""
|
|
320
|
+
with patch(
|
|
321
|
+
"kiln_ai.adapters.adapter_registry.lite_llm_core_config_for_provider"
|
|
322
|
+
) as mock:
|
|
323
|
+
yield mock
|
|
324
|
+
|
|
325
|
+
|
|
326
|
+
def test_adapter_for_task_core_provider_mapping(
|
|
327
|
+
mock_lite_llm_core_config_for_provider, basic_task
|
|
328
|
+
):
|
|
329
|
+
"""Test adapter_for_task correctly maps virtual providers to core providers."""
|
|
330
|
+
# Mock auth details for the underlying provider
|
|
331
|
+
mock_lite_llm_core_config = LiteLlmCoreConfig(
|
|
332
|
+
additional_body_options={"api_key": "test-openai-key"},
|
|
333
|
+
)
|
|
334
|
+
mock_lite_llm_core_config_for_provider.return_value = mock_lite_llm_core_config
|
|
335
|
+
|
|
336
|
+
# Use a virtual provider that should map to openai
|
|
337
|
+
with patch("kiln_ai.adapters.adapter_registry.core_provider") as mock_core_provider:
|
|
338
|
+
mock_core_provider.return_value = ModelProviderName.openai
|
|
339
|
+
|
|
340
|
+
adapter = adapter_for_task(
|
|
341
|
+
kiln_task=basic_task,
|
|
342
|
+
run_config_properties=RunConfigProperties(
|
|
343
|
+
model_name="fake-gpt",
|
|
344
|
+
model_provider_name=ModelProviderName.kiln_fine_tune,
|
|
345
|
+
prompt_id="simple_prompt_builder",
|
|
346
|
+
structured_output_mode="json_schema",
|
|
347
|
+
),
|
|
348
|
+
)
|
|
349
|
+
|
|
350
|
+
# Verify core_provider was called to map virtual to actual provider
|
|
351
|
+
mock_core_provider.assert_called_once_with(
|
|
352
|
+
"fake-gpt", ModelProviderName.kiln_fine_tune
|
|
353
|
+
)
|
|
354
|
+
|
|
355
|
+
# Verify auth was fetched for the mapped core provider
|
|
356
|
+
mock_lite_llm_core_config_for_provider.assert_called_once_with(
|
|
357
|
+
ModelProviderName.openai, None
|
|
358
|
+
)
|
|
359
|
+
|
|
360
|
+
# Verify adapter is created correctly
|
|
361
|
+
assert isinstance(adapter, LiteLlmAdapter)
|
|
362
|
+
assert (
|
|
363
|
+
adapter.config.additional_body_options
|
|
364
|
+
== mock_lite_llm_core_config.additional_body_options
|
|
365
|
+
)
|
|
366
|
+
assert adapter.config.base_url == mock_lite_llm_core_config.base_url
|
|
367
|
+
assert (
|
|
368
|
+
adapter.config.default_headers == mock_lite_llm_core_config.default_headers
|
|
369
|
+
)
|
|
370
|
+
|
|
371
|
+
|
|
372
|
+
def test_adapter_for_task_preserves_run_config_properties(
|
|
373
|
+
mock_lite_llm_core_config_for_provider, basic_task
|
|
374
|
+
):
|
|
375
|
+
"""Test adapter_for_task preserves all run config properties correctly."""
|
|
376
|
+
mock_lite_llm_core_config = LiteLlmCoreConfig(
|
|
377
|
+
additional_body_options={"api_key": "test-key"},
|
|
378
|
+
)
|
|
379
|
+
mock_lite_llm_core_config_for_provider.return_value = mock_lite_llm_core_config
|
|
380
|
+
|
|
381
|
+
run_config_props = RunConfigProperties(
|
|
382
|
+
model_name="gpt-4",
|
|
383
|
+
model_provider_name=ModelProviderName.openai,
|
|
384
|
+
prompt_id="simple_prompt_builder",
|
|
385
|
+
structured_output_mode="function_calling",
|
|
386
|
+
temperature=0.7,
|
|
387
|
+
top_p=0.9,
|
|
388
|
+
)
|
|
389
|
+
|
|
390
|
+
adapter = adapter_for_task(
|
|
391
|
+
kiln_task=basic_task,
|
|
392
|
+
run_config_properties=run_config_props,
|
|
393
|
+
)
|
|
394
|
+
|
|
395
|
+
# Verify all run config properties are preserved
|
|
396
|
+
assert adapter.config.run_config_properties.model_name == "gpt-4"
|
|
397
|
+
assert (
|
|
398
|
+
adapter.config.run_config_properties.model_provider_name
|
|
399
|
+
== ModelProviderName.openai
|
|
400
|
+
)
|
|
401
|
+
assert adapter.config.run_config_properties.prompt_id == "simple_prompt_builder"
|
|
402
|
+
assert (
|
|
403
|
+
adapter.config.run_config_properties.structured_output_mode
|
|
404
|
+
== "function_calling"
|
|
405
|
+
)
|
|
406
|
+
assert adapter.config.run_config_properties.temperature == 0.7
|
|
407
|
+
assert adapter.config.run_config_properties.top_p == 0.9
|
|
408
|
+
|
|
409
|
+
|
|
410
|
+
def test_adapter_for_task_with_base_adapter_config(
|
|
411
|
+
mock_lite_llm_core_config_for_provider, basic_task
|
|
412
|
+
):
|
|
413
|
+
"""Test adapter_for_task correctly passes through base_adapter_config."""
|
|
414
|
+
mock_lite_llm_core_config = LiteLlmCoreConfig(
|
|
415
|
+
additional_body_options={"api_key": "test-key"},
|
|
416
|
+
)
|
|
417
|
+
mock_lite_llm_core_config_for_provider.return_value = mock_lite_llm_core_config
|
|
418
|
+
|
|
419
|
+
base_config = AdapterConfig(
|
|
420
|
+
allow_saving=False,
|
|
421
|
+
top_logprobs=5,
|
|
422
|
+
default_tags=["test-tag-1", "test-tag-2"],
|
|
423
|
+
)
|
|
424
|
+
|
|
243
425
|
adapter = adapter_for_task(
|
|
244
426
|
kiln_task=basic_task,
|
|
245
427
|
run_config_properties=RunConfigProperties(
|
|
246
|
-
model_name="
|
|
247
|
-
model_provider_name=ModelProviderName.
|
|
428
|
+
model_name="gpt-4",
|
|
429
|
+
model_provider_name=ModelProviderName.openai,
|
|
248
430
|
prompt_id="simple_prompt_builder",
|
|
249
431
|
structured_output_mode="json_schema",
|
|
250
432
|
),
|
|
433
|
+
base_adapter_config=base_config,
|
|
251
434
|
)
|
|
252
435
|
|
|
253
|
-
|
|
254
|
-
assert
|
|
436
|
+
# Verify base adapter config is preserved
|
|
437
|
+
assert adapter.base_adapter_config == base_config
|
|
438
|
+
assert adapter.base_adapter_config.allow_saving is False
|
|
439
|
+
assert adapter.base_adapter_config.top_logprobs == 5
|
|
440
|
+
assert adapter.base_adapter_config.default_tags == ["test-tag-1", "test-tag-2"]
|
|
441
|
+
|
|
442
|
+
|
|
443
|
+
@pytest.fixture
|
|
444
|
+
def comprehensive_mock_config():
|
|
445
|
+
"""Mock all config values for comprehensive testing."""
|
|
446
|
+
config_instance = Mock()
|
|
447
|
+
|
|
448
|
+
# Set up all config values that the original switch used
|
|
449
|
+
config_instance.open_router_api_key = "test-openrouter-key"
|
|
450
|
+
config_instance.open_ai_api_key = "test-openai-key"
|
|
451
|
+
config_instance.groq_api_key = "test-groq-key"
|
|
452
|
+
config_instance.bedrock_access_key = "test-aws-access-key"
|
|
453
|
+
config_instance.bedrock_secret_key = "test-aws-secret-key"
|
|
454
|
+
config_instance.ollama_base_url = "http://test-ollama:11434"
|
|
455
|
+
config_instance.fireworks_api_key = "test-fireworks-key"
|
|
456
|
+
config_instance.anthropic_api_key = "test-anthropic-key"
|
|
457
|
+
config_instance.gemini_api_key = "test-gemini-key"
|
|
458
|
+
config_instance.vertex_project_id = "test-vertex-project"
|
|
459
|
+
config_instance.vertex_location = "us-central1"
|
|
460
|
+
config_instance.together_api_key = "test-together-key"
|
|
461
|
+
config_instance.azure_openai_api_key = "test-azure-key"
|
|
462
|
+
config_instance.azure_openai_endpoint = "https://test.openai.azure.com"
|
|
463
|
+
config_instance.huggingface_api_key = "test-hf-key"
|
|
464
|
+
|
|
465
|
+
# Mock both import locations - the refactored code uses provider_tools.Config
|
|
466
|
+
# and the original switch recreation uses local Config import
|
|
467
|
+
with (
|
|
468
|
+
patch("kiln_ai.adapters.provider_tools.Config") as provider_tools_mock,
|
|
469
|
+
patch("kiln_ai.adapters.test_adapter_registry.Config") as test_mock,
|
|
470
|
+
):
|
|
471
|
+
provider_tools_mock.shared.return_value = config_instance
|
|
472
|
+
test_mock.shared.return_value = config_instance
|
|
473
|
+
|
|
474
|
+
yield provider_tools_mock
|
|
475
|
+
|
|
476
|
+
|
|
477
|
+
def create_config_declarative(
|
|
478
|
+
provider_name: ModelProviderName, run_config_properties: RunConfigProperties
|
|
479
|
+
) -> LiteLlmConfig:
|
|
480
|
+
"""Regression test, but also easier to verify the config is what we expect for each provider."""
|
|
481
|
+
match provider_name:
|
|
482
|
+
case ModelProviderName.openrouter:
|
|
483
|
+
return LiteLlmConfig(
|
|
484
|
+
run_config_properties=run_config_properties,
|
|
485
|
+
base_url=getenv("OPENROUTER_BASE_URL")
|
|
486
|
+
or "https://openrouter.ai/api/v1",
|
|
487
|
+
default_headers={
|
|
488
|
+
"HTTP-Referer": "https://kiln.tech/openrouter",
|
|
489
|
+
"X-Title": "KilnAI",
|
|
490
|
+
},
|
|
491
|
+
additional_body_options={
|
|
492
|
+
"api_key": Config.shared().open_router_api_key,
|
|
493
|
+
},
|
|
494
|
+
)
|
|
495
|
+
case ModelProviderName.openai:
|
|
496
|
+
return LiteLlmConfig(
|
|
497
|
+
run_config_properties=run_config_properties,
|
|
498
|
+
additional_body_options={
|
|
499
|
+
"api_key": Config.shared().open_ai_api_key,
|
|
500
|
+
},
|
|
501
|
+
)
|
|
502
|
+
case ModelProviderName.groq:
|
|
503
|
+
return LiteLlmConfig(
|
|
504
|
+
run_config_properties=run_config_properties,
|
|
505
|
+
additional_body_options={
|
|
506
|
+
"api_key": Config.shared().groq_api_key,
|
|
507
|
+
},
|
|
508
|
+
)
|
|
509
|
+
case ModelProviderName.amazon_bedrock:
|
|
510
|
+
return LiteLlmConfig(
|
|
511
|
+
run_config_properties=run_config_properties,
|
|
512
|
+
additional_body_options={
|
|
513
|
+
"aws_access_key_id": Config.shared().bedrock_access_key,
|
|
514
|
+
"aws_secret_access_key": Config.shared().bedrock_secret_key,
|
|
515
|
+
"aws_region_name": "us-west-2",
|
|
516
|
+
},
|
|
517
|
+
)
|
|
518
|
+
case ModelProviderName.ollama:
|
|
519
|
+
ollama_base_url = (
|
|
520
|
+
Config.shared().ollama_base_url or "http://localhost:11434"
|
|
521
|
+
)
|
|
522
|
+
return LiteLlmConfig(
|
|
523
|
+
run_config_properties=run_config_properties,
|
|
524
|
+
base_url=ollama_base_url + "/v1",
|
|
525
|
+
additional_body_options={
|
|
526
|
+
"api_key": "NA",
|
|
527
|
+
},
|
|
528
|
+
)
|
|
529
|
+
case ModelProviderName.fireworks_ai:
|
|
530
|
+
return LiteLlmConfig(
|
|
531
|
+
run_config_properties=run_config_properties,
|
|
532
|
+
additional_body_options={
|
|
533
|
+
"api_key": Config.shared().fireworks_api_key,
|
|
534
|
+
},
|
|
535
|
+
)
|
|
536
|
+
case ModelProviderName.anthropic:
|
|
537
|
+
return LiteLlmConfig(
|
|
538
|
+
run_config_properties=run_config_properties,
|
|
539
|
+
additional_body_options={
|
|
540
|
+
"api_key": Config.shared().anthropic_api_key,
|
|
541
|
+
},
|
|
542
|
+
)
|
|
543
|
+
case ModelProviderName.gemini_api:
|
|
544
|
+
return LiteLlmConfig(
|
|
545
|
+
run_config_properties=run_config_properties,
|
|
546
|
+
additional_body_options={
|
|
547
|
+
"api_key": Config.shared().gemini_api_key,
|
|
548
|
+
},
|
|
549
|
+
)
|
|
550
|
+
case ModelProviderName.vertex:
|
|
551
|
+
return LiteLlmConfig(
|
|
552
|
+
run_config_properties=run_config_properties,
|
|
553
|
+
additional_body_options={
|
|
554
|
+
"vertex_project": Config.shared().vertex_project_id,
|
|
555
|
+
"vertex_location": Config.shared().vertex_location,
|
|
556
|
+
},
|
|
557
|
+
)
|
|
558
|
+
case ModelProviderName.together_ai:
|
|
559
|
+
return LiteLlmConfig(
|
|
560
|
+
run_config_properties=run_config_properties,
|
|
561
|
+
additional_body_options={
|
|
562
|
+
"api_key": Config.shared().together_api_key,
|
|
563
|
+
},
|
|
564
|
+
)
|
|
565
|
+
case ModelProviderName.azure_openai:
|
|
566
|
+
return LiteLlmConfig(
|
|
567
|
+
base_url=Config.shared().azure_openai_endpoint,
|
|
568
|
+
run_config_properties=run_config_properties,
|
|
569
|
+
additional_body_options={
|
|
570
|
+
"api_key": Config.shared().azure_openai_api_key,
|
|
571
|
+
"api_version": "2025-02-01-preview",
|
|
572
|
+
},
|
|
573
|
+
)
|
|
574
|
+
case ModelProviderName.huggingface:
|
|
575
|
+
return LiteLlmConfig(
|
|
576
|
+
run_config_properties=run_config_properties,
|
|
577
|
+
additional_body_options={
|
|
578
|
+
"api_key": Config.shared().huggingface_api_key,
|
|
579
|
+
},
|
|
580
|
+
)
|
|
581
|
+
case _:
|
|
582
|
+
raise ValueError(f"Test setup error: unsupported provider {provider_name}")
|
|
583
|
+
|
|
584
|
+
|
|
585
|
+
@pytest.mark.parametrize(
|
|
586
|
+
"provider_name,model_name",
|
|
587
|
+
[
|
|
588
|
+
(ModelProviderName.openrouter, "anthropic/claude-3-opus"),
|
|
589
|
+
(ModelProviderName.openai, "gpt-4"),
|
|
590
|
+
(ModelProviderName.groq, "llama3-8b-8192"),
|
|
591
|
+
(ModelProviderName.amazon_bedrock, "anthropic.claude-3-opus-20240229-v1:0"),
|
|
592
|
+
(ModelProviderName.ollama, "llama3.2"),
|
|
593
|
+
(
|
|
594
|
+
ModelProviderName.fireworks_ai,
|
|
595
|
+
"accounts/fireworks/models/llama-v3p1-8b-instruct",
|
|
596
|
+
),
|
|
597
|
+
(ModelProviderName.anthropic, "claude-3-opus-20240229"),
|
|
598
|
+
(ModelProviderName.gemini_api, "gemini-1.5-pro"),
|
|
599
|
+
(ModelProviderName.vertex, "gemini-1.5-pro"),
|
|
600
|
+
(ModelProviderName.together_ai, "meta-llama/Llama-3.2-3B-Instruct-Turbo"),
|
|
601
|
+
(ModelProviderName.azure_openai, "gpt-4"),
|
|
602
|
+
(ModelProviderName.huggingface, "microsoft/DialoGPT-medium"),
|
|
603
|
+
],
|
|
604
|
+
)
|
|
605
|
+
def test_adapter_for_task_matches_original_switch(
|
|
606
|
+
comprehensive_mock_config, basic_task, provider_name, model_name
|
|
607
|
+
):
|
|
608
|
+
"""
|
|
609
|
+
Regression test: Verify refactored adapter_for_task produces identical results
|
|
610
|
+
to the original switch statement for all providers.
|
|
611
|
+
"""
|
|
612
|
+
# Standard run config properties for testing
|
|
613
|
+
run_config_props = RunConfigProperties(
|
|
614
|
+
model_name=model_name,
|
|
615
|
+
model_provider_name=provider_name,
|
|
616
|
+
prompt_id="simple_prompt_builder",
|
|
617
|
+
structured_output_mode="json_schema",
|
|
618
|
+
)
|
|
619
|
+
|
|
620
|
+
# Get the adapter from the new refactored function
|
|
621
|
+
adapter = adapter_for_task(
|
|
622
|
+
kiln_task=basic_task,
|
|
623
|
+
run_config_properties=run_config_props,
|
|
624
|
+
)
|
|
625
|
+
|
|
626
|
+
# Create what the original switch would have produced
|
|
627
|
+
expected_config = create_config_declarative(provider_name, run_config_props)
|
|
628
|
+
|
|
629
|
+
# Compare the configurations field by field
|
|
630
|
+
actual_config = adapter.config
|
|
631
|
+
|
|
632
|
+
assert actual_config.run_config_properties == expected_config.run_config_properties
|
|
633
|
+
assert actual_config.base_url == expected_config.base_url
|
|
634
|
+
assert actual_config.default_headers == expected_config.default_headers
|
|
255
635
|
assert (
|
|
256
|
-
|
|
257
|
-
== ModelProviderName.kiln_fine_tune
|
|
636
|
+
actual_config.additional_body_options == expected_config.additional_body_options
|
|
258
637
|
)
|
|
259
|
-
# Kiln model name here, but the underlying openai model id below
|
|
260
|
-
assert adapter.config.run_config_properties.model_name == "proj::task::tune"
|
|
261
638
|
|
|
262
|
-
|
|
263
|
-
|
|
639
|
+
|
|
640
|
+
@patch.dict(
|
|
641
|
+
"os.environ", {"OPENROUTER_BASE_URL": "https://custom-openrouter.example.com"}
|
|
642
|
+
)
|
|
643
|
+
def test_adapter_for_task_matches_original_switch_openrouter_env_var(
|
|
644
|
+
comprehensive_mock_config, basic_task
|
|
645
|
+
):
|
|
646
|
+
"""
|
|
647
|
+
Test that OpenRouter respects the OPENROUTER_BASE_URL environment variable
|
|
648
|
+
exactly like the original switch statement did.
|
|
649
|
+
"""
|
|
650
|
+
run_config_props = RunConfigProperties(
|
|
651
|
+
model_name="anthropic/claude-3-opus",
|
|
652
|
+
model_provider_name=ModelProviderName.openrouter,
|
|
653
|
+
prompt_id="simple_prompt_builder",
|
|
654
|
+
structured_output_mode="json_schema",
|
|
264
655
|
)
|
|
265
|
-
|
|
266
|
-
|
|
656
|
+
|
|
657
|
+
# Get adapter from refactored function
|
|
658
|
+
adapter = adapter_for_task(
|
|
659
|
+
kiln_task=basic_task,
|
|
660
|
+
run_config_properties=run_config_props,
|
|
661
|
+
)
|
|
662
|
+
|
|
663
|
+
# Create what original switch would have produced
|
|
664
|
+
expected_config = create_config_declarative(
|
|
665
|
+
ModelProviderName.openrouter, run_config_props
|
|
666
|
+
)
|
|
667
|
+
|
|
668
|
+
# Both should use the custom environment variable
|
|
669
|
+
assert adapter.config.base_url == "https://custom-openrouter.example.com"
|
|
670
|
+
assert adapter.config.base_url == expected_config.base_url
|
|
671
|
+
|
|
672
|
+
|
|
673
|
+
def test_adapter_for_task_matches_original_switch_ollama_default_url(
|
|
674
|
+
comprehensive_mock_config, basic_task
|
|
675
|
+
):
|
|
676
|
+
"""
|
|
677
|
+
Test that Ollama falls back to default URL when none configured,
|
|
678
|
+
exactly like the original switch statement did.
|
|
679
|
+
"""
|
|
680
|
+
# Override mock to return None for ollama_base_url
|
|
681
|
+
comprehensive_mock_config.shared.return_value.ollama_base_url = None
|
|
682
|
+
|
|
683
|
+
run_config_props = RunConfigProperties(
|
|
684
|
+
model_name="llama3.2",
|
|
685
|
+
model_provider_name=ModelProviderName.ollama,
|
|
686
|
+
prompt_id="simple_prompt_builder",
|
|
687
|
+
structured_output_mode="json_schema",
|
|
688
|
+
)
|
|
689
|
+
|
|
690
|
+
# Get adapter from refactored function
|
|
691
|
+
adapter = adapter_for_task(
|
|
692
|
+
kiln_task=basic_task,
|
|
693
|
+
run_config_properties=run_config_props,
|
|
694
|
+
)
|
|
695
|
+
|
|
696
|
+
# Create what original switch would have produced
|
|
697
|
+
expected_config = create_config_declarative(
|
|
698
|
+
ModelProviderName.ollama, run_config_props
|
|
699
|
+
)
|
|
700
|
+
|
|
701
|
+
# Both should use the default localhost URL
|
|
702
|
+
assert adapter.config.base_url == "http://localhost:11434/v1"
|
|
703
|
+
assert adapter.config.base_url == expected_config.base_url
|
|
704
|
+
|
|
705
|
+
|
|
706
|
+
@pytest.fixture
|
|
707
|
+
def mock_shared_config():
|
|
708
|
+
with patch("kiln_ai.adapters.provider_tools.Config.shared") as mock:
|
|
709
|
+
config = Mock()
|
|
710
|
+
config.openai_compatible_providers = [
|
|
711
|
+
{
|
|
712
|
+
"name": "test_provider",
|
|
713
|
+
"base_url": "https://api.test.com",
|
|
714
|
+
"api_key": "test-key",
|
|
715
|
+
},
|
|
716
|
+
{
|
|
717
|
+
"name": "no_key_provider",
|
|
718
|
+
"base_url": "https://api.nokey.com",
|
|
719
|
+
},
|
|
720
|
+
]
|
|
721
|
+
mock.return_value = config
|
|
722
|
+
yield mock
|
|
723
|
+
|
|
724
|
+
|
|
725
|
+
def test_lite_llm_config_no_api_key(mock_shared_config):
|
|
726
|
+
"""Test provider creation without API key (should work as some providers don't require it, but should pass NA to LiteLLM as it requires one)"""
|
|
727
|
+
config = lite_llm_core_config_for_provider(
|
|
728
|
+
ModelProviderName.openai_compatible, "no_key_provider"
|
|
729
|
+
)
|
|
730
|
+
assert config is not None
|
|
731
|
+
assert config.additional_body_options == {"api_key": "NA"}
|
|
732
|
+
assert config.base_url == "https://api.nokey.com"
|
|
733
|
+
|
|
734
|
+
|
|
735
|
+
@pytest.mark.parametrize(
|
|
736
|
+
"provider_name",
|
|
737
|
+
[
|
|
738
|
+
ModelProviderName.kiln_fine_tune,
|
|
739
|
+
ModelProviderName.kiln_custom_registry,
|
|
740
|
+
],
|
|
741
|
+
)
|
|
742
|
+
def test_lite_llm_core_config_for_provider_virtual_providers(
|
|
743
|
+
mock_config, basic_task, provider_name
|
|
744
|
+
):
|
|
745
|
+
# patch core_provider to return None
|
|
746
|
+
with patch("kiln_ai.adapters.adapter_registry.core_provider") as mock_core_provider:
|
|
747
|
+
mock_core_provider.return_value = provider_name
|
|
748
|
+
|
|
749
|
+
# virtual providers are not supported and should raise an error
|
|
750
|
+
with pytest.raises(ValueError, match="not a core provider"):
|
|
751
|
+
adapter_for_task(
|
|
752
|
+
basic_task,
|
|
753
|
+
RunConfigProperties(
|
|
754
|
+
model_name="project::task::finetune",
|
|
755
|
+
model_provider_name=provider_name,
|
|
756
|
+
prompt_id="simple_prompt_builder",
|
|
757
|
+
structured_output_mode="json_schema",
|
|
758
|
+
),
|
|
759
|
+
)
|
|
267
760
|
|
|
268
761
|
|
|
269
762
|
def test_docker_model_runner_adapter_creation(mock_config, basic_task):
|