kiln-ai 0.20.1__py3-none-any.whl → 0.22.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of kiln-ai might be problematic.
- kiln_ai/adapters/__init__.py +6 -0
- kiln_ai/adapters/adapter_registry.py +43 -226
- kiln_ai/adapters/chunkers/__init__.py +13 -0
- kiln_ai/adapters/chunkers/base_chunker.py +42 -0
- kiln_ai/adapters/chunkers/chunker_registry.py +16 -0
- kiln_ai/adapters/chunkers/fixed_window_chunker.py +39 -0
- kiln_ai/adapters/chunkers/helpers.py +23 -0
- kiln_ai/adapters/chunkers/test_base_chunker.py +63 -0
- kiln_ai/adapters/chunkers/test_chunker_registry.py +28 -0
- kiln_ai/adapters/chunkers/test_fixed_window_chunker.py +346 -0
- kiln_ai/adapters/chunkers/test_helpers.py +75 -0
- kiln_ai/adapters/data_gen/test_data_gen_task.py +9 -3
- kiln_ai/adapters/embedding/__init__.py +0 -0
- kiln_ai/adapters/embedding/base_embedding_adapter.py +44 -0
- kiln_ai/adapters/embedding/embedding_registry.py +32 -0
- kiln_ai/adapters/embedding/litellm_embedding_adapter.py +199 -0
- kiln_ai/adapters/embedding/test_base_embedding_adapter.py +283 -0
- kiln_ai/adapters/embedding/test_embedding_registry.py +166 -0
- kiln_ai/adapters/embedding/test_litellm_embedding_adapter.py +1149 -0
- kiln_ai/adapters/eval/eval_runner.py +6 -2
- kiln_ai/adapters/eval/test_base_eval.py +1 -3
- kiln_ai/adapters/eval/test_g_eval.py +1 -1
- kiln_ai/adapters/extractors/__init__.py +18 -0
- kiln_ai/adapters/extractors/base_extractor.py +72 -0
- kiln_ai/adapters/extractors/encoding.py +20 -0
- kiln_ai/adapters/extractors/extractor_registry.py +44 -0
- kiln_ai/adapters/extractors/extractor_runner.py +112 -0
- kiln_ai/adapters/extractors/litellm_extractor.py +406 -0
- kiln_ai/adapters/extractors/test_base_extractor.py +244 -0
- kiln_ai/adapters/extractors/test_encoding.py +54 -0
- kiln_ai/adapters/extractors/test_extractor_registry.py +181 -0
- kiln_ai/adapters/extractors/test_extractor_runner.py +181 -0
- kiln_ai/adapters/extractors/test_litellm_extractor.py +1290 -0
- kiln_ai/adapters/fine_tune/test_dataset_formatter.py +2 -2
- kiln_ai/adapters/fine_tune/test_fireworks_finetune.py +2 -6
- kiln_ai/adapters/fine_tune/test_together_finetune.py +2 -6
- kiln_ai/adapters/ml_embedding_model_list.py +494 -0
- kiln_ai/adapters/ml_model_list.py +876 -18
- kiln_ai/adapters/model_adapters/litellm_adapter.py +40 -75
- kiln_ai/adapters/model_adapters/test_litellm_adapter.py +79 -1
- kiln_ai/adapters/model_adapters/test_litellm_adapter_tools.py +119 -5
- kiln_ai/adapters/model_adapters/test_saving_adapter_results.py +9 -3
- kiln_ai/adapters/model_adapters/test_structured_output.py +9 -10
- kiln_ai/adapters/ollama_tools.py +69 -12
- kiln_ai/adapters/provider_tools.py +190 -46
- kiln_ai/adapters/rag/deduplication.py +49 -0
- kiln_ai/adapters/rag/progress.py +252 -0
- kiln_ai/adapters/rag/rag_runners.py +844 -0
- kiln_ai/adapters/rag/test_deduplication.py +195 -0
- kiln_ai/adapters/rag/test_progress.py +785 -0
- kiln_ai/adapters/rag/test_rag_runners.py +2376 -0
- kiln_ai/adapters/remote_config.py +80 -8
- kiln_ai/adapters/test_adapter_registry.py +579 -86
- kiln_ai/adapters/test_ml_embedding_model_list.py +239 -0
- kiln_ai/adapters/test_ml_model_list.py +202 -0
- kiln_ai/adapters/test_ollama_tools.py +340 -1
- kiln_ai/adapters/test_prompt_builders.py +1 -1
- kiln_ai/adapters/test_provider_tools.py +199 -8
- kiln_ai/adapters/test_remote_config.py +551 -56
- kiln_ai/adapters/vector_store/__init__.py +1 -0
- kiln_ai/adapters/vector_store/base_vector_store_adapter.py +83 -0
- kiln_ai/adapters/vector_store/lancedb_adapter.py +389 -0
- kiln_ai/adapters/vector_store/test_base_vector_store.py +160 -0
- kiln_ai/adapters/vector_store/test_lancedb_adapter.py +1841 -0
- kiln_ai/adapters/vector_store/test_vector_store_registry.py +199 -0
- kiln_ai/adapters/vector_store/vector_store_registry.py +33 -0
- kiln_ai/datamodel/__init__.py +16 -13
- kiln_ai/datamodel/basemodel.py +201 -4
- kiln_ai/datamodel/chunk.py +158 -0
- kiln_ai/datamodel/datamodel_enums.py +27 -0
- kiln_ai/datamodel/embedding.py +64 -0
- kiln_ai/datamodel/external_tool_server.py +206 -54
- kiln_ai/datamodel/extraction.py +317 -0
- kiln_ai/datamodel/project.py +33 -1
- kiln_ai/datamodel/rag.py +79 -0
- kiln_ai/datamodel/task.py +5 -0
- kiln_ai/datamodel/task_output.py +41 -11
- kiln_ai/datamodel/test_attachment.py +649 -0
- kiln_ai/datamodel/test_basemodel.py +270 -14
- kiln_ai/datamodel/test_chunk_models.py +317 -0
- kiln_ai/datamodel/test_dataset_split.py +1 -1
- kiln_ai/datamodel/test_datasource.py +50 -0
- kiln_ai/datamodel/test_embedding_models.py +448 -0
- kiln_ai/datamodel/test_eval_model.py +6 -6
- kiln_ai/datamodel/test_external_tool_server.py +534 -152
- kiln_ai/datamodel/test_extraction_chunk.py +206 -0
- kiln_ai/datamodel/test_extraction_model.py +501 -0
- kiln_ai/datamodel/test_rag.py +641 -0
- kiln_ai/datamodel/test_task.py +35 -1
- kiln_ai/datamodel/test_tool_id.py +187 -1
- kiln_ai/datamodel/test_vector_store.py +320 -0
- kiln_ai/datamodel/tool_id.py +58 -0
- kiln_ai/datamodel/vector_store.py +141 -0
- kiln_ai/tools/base_tool.py +12 -3
- kiln_ai/tools/built_in_tools/math_tools.py +12 -4
- kiln_ai/tools/kiln_task_tool.py +158 -0
- kiln_ai/tools/mcp_server_tool.py +2 -2
- kiln_ai/tools/mcp_session_manager.py +51 -22
- kiln_ai/tools/rag_tools.py +164 -0
- kiln_ai/tools/test_kiln_task_tool.py +527 -0
- kiln_ai/tools/test_mcp_server_tool.py +4 -15
- kiln_ai/tools/test_mcp_session_manager.py +187 -227
- kiln_ai/tools/test_rag_tools.py +929 -0
- kiln_ai/tools/test_tool_registry.py +290 -7
- kiln_ai/tools/tool_registry.py +69 -16
- kiln_ai/utils/__init__.py +3 -0
- kiln_ai/utils/async_job_runner.py +62 -17
- kiln_ai/utils/config.py +2 -2
- kiln_ai/utils/env.py +15 -0
- kiln_ai/utils/filesystem.py +14 -0
- kiln_ai/utils/filesystem_cache.py +60 -0
- kiln_ai/utils/litellm.py +94 -0
- kiln_ai/utils/lock.py +100 -0
- kiln_ai/utils/mime_type.py +38 -0
- kiln_ai/utils/open_ai_types.py +19 -2
- kiln_ai/utils/pdf_utils.py +59 -0
- kiln_ai/utils/test_async_job_runner.py +151 -35
- kiln_ai/utils/test_env.py +142 -0
- kiln_ai/utils/test_filesystem_cache.py +316 -0
- kiln_ai/utils/test_litellm.py +206 -0
- kiln_ai/utils/test_lock.py +185 -0
- kiln_ai/utils/test_mime_type.py +66 -0
- kiln_ai/utils/test_open_ai_types.py +88 -12
- kiln_ai/utils/test_pdf_utils.py +86 -0
- kiln_ai/utils/test_uuid.py +111 -0
- kiln_ai/utils/test_validation.py +524 -0
- kiln_ai/utils/uuid.py +9 -0
- kiln_ai/utils/validation.py +90 -0
- {kiln_ai-0.20.1.dist-info → kiln_ai-0.22.0.dist-info}/METADATA +9 -1
- kiln_ai-0.22.0.dist-info/RECORD +213 -0
- kiln_ai-0.20.1.dist-info/RECORD +0 -138
- {kiln_ai-0.20.1.dist-info → kiln_ai-0.22.0.dist-info}/WHEEL +0 -0
- {kiln_ai-0.20.1.dist-info → kiln_ai-0.22.0.dist-info}/licenses/LICENSE.txt +0 -0
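
Taken together, the new modules form a document-processing and RAG stack: extractors pull text from source files, chunkers split it, embedding adapters (backed by a new embedding model list) vectorize it, vector-store adapters (LanceDB) index it, and the new RAG runners and rag_tools wire the pipeline into tools — alongside a Kiln-task tool and reworked provider/adapter registries. The reconstructed diff hunks below show representative changes.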
kiln_ai/adapters/test_ollama_tools.py
@@ -1,13 +1,85 @@
 import json
+from unittest.mock import patch
 
+import pytest
+
+from kiln_ai.adapters.ml_embedding_model_list import (
+    KilnEmbeddingModel,
+    KilnEmbeddingModelProvider,
+)
+
+# Mock data for testing - using proper Pydantic model instances
+from kiln_ai.adapters.ml_model_list import KilnModel, KilnModelProvider
 from kiln_ai.adapters.ollama_tools import (
     OllamaConnection,
+    ollama_embedding_model_installed,
     ollama_model_installed,
     parse_ollama_tags,
 )
+from kiln_ai.datamodel.datamodel_enums import ModelProviderName
 
+MOCK_BUILT_IN_MODELS = [
+    KilnModel(
+        family="phi",
+        name="phi3.5",
+        friendly_name="phi3.5",
+        providers=[
+            KilnModelProvider(
+                name=ModelProviderName.ollama,
+                model_id="phi3.5",
+                ollama_model_aliases=None,
+            )
+        ],
+    ),
+    KilnModel(
+        family="gemma",
+        name="gemma2",
+        friendly_name="gemma2",
+        providers=[
+            KilnModelProvider(
+                name=ModelProviderName.ollama,
+                model_id="gemma2:2b",
+                ollama_model_aliases=None,
+            )
+        ],
+    ),
+    KilnModel(
+        family="llama",
+        name="llama3.1",
+        friendly_name="llama3.1",
+        providers=[
+            KilnModelProvider(
+                name=ModelProviderName.ollama,
+                model_id="llama3.1",
+                ollama_model_aliases=None,
+            )
+        ],
+    ),
+]
 
-def test_parse_ollama_tags_models():
+MOCK_BUILT_IN_EMBEDDING_MODELS = [
+    KilnEmbeddingModel(
+        family="gemma",
+        name="embeddinggemma",
+        friendly_name="embeddinggemma",
+        providers=[
+            KilnEmbeddingModelProvider(
+                name=ModelProviderName.ollama,
+                model_id="embeddinggemma:300m",
+                n_dimensions=768,
+                ollama_model_aliases=["embeddinggemma"],
+            )
+        ],
+    ),
+]
+
+
+@patch("kiln_ai.adapters.ollama_tools.built_in_models", MOCK_BUILT_IN_MODELS)
+@patch(
+    "kiln_ai.adapters.ollama_tools.built_in_embedding_models",
+    MOCK_BUILT_IN_EMBEDDING_MODELS,
+)
+def test_parse_ollama_tags_models():
     json_response = '{"models":[{"name":"scosman_net","model":"scosman_net:latest"},{"name":"phi3.5:latest","model":"phi3.5:latest","modified_at":"2024-10-02T12:04:35.191519822-04:00","size":2176178843,"digest":"61819fb370a3c1a9be6694869331e5f85f867a079e9271d66cb223acb81d04ba","details":{"parent_model":"","format":"gguf","family":"phi3","families":["phi3"],"parameter_size":"3.8B","quantization_level":"Q4_0"}},{"name":"gemma2:2b","model":"gemma2:2b","modified_at":"2024-09-09T16:46:38.64348929-04:00","size":1629518495,"digest":"8ccf136fdd5298f3ffe2d69862750ea7fb56555fa4d5b18c04e3fa4d82ee09d7","details":{"parent_model":"","format":"gguf","family":"gemma2","families":["gemma2"],"parameter_size":"2.6B","quantization_level":"Q4_0"}},{"name":"llama3.1:latest","model":"llama3.1:latest","modified_at":"2024-09-01T17:19:43.481523695-04:00","size":4661230720,"digest":"f66fc8dc39ea206e03ff6764fcc696b1b4dfb693f0b6ef751731dd4e6269046e","details":{"parent_model":"","format":"gguf","family":"llama","families":["llama"],"parameter_size":"8.0B","quantization_level":"Q4_0"}}]}'
     tags = json.loads(json_response)
     conn = parse_ollama_tags(tags)
@@ -16,7 +88,52 @@ def test_parse_ollama_tags_no_models():
     assert "llama3.1:latest" in conn.supported_models
     assert "scosman_net:latest" in conn.untested_models
 
+    # there should be no embedding models because the tags response does not include any embedding models
+    # that are in the built-in embedding models list
+    assert len(conn.supported_embedding_models) == 0
 
+
+@patch("kiln_ai.adapters.ollama_tools.built_in_models", MOCK_BUILT_IN_MODELS)
+@patch(
+    "kiln_ai.adapters.ollama_tools.built_in_embedding_models",
+    MOCK_BUILT_IN_EMBEDDING_MODELS,
+)
+@pytest.mark.parametrize("json_response", ["{}", '{"models": []}'])
+def test_parse_ollama_tags_no_models(json_response):
+    tags = json.loads(json_response)
+    conn = parse_ollama_tags(tags)
+    assert (
+        conn.message
+        == "Ollama is running, but no supported models are installed. Install one or more supported model, like 'ollama pull phi3.5'."
+    )
+    assert len(conn.supported_models) == 0
+    assert len(conn.untested_models) == 0
+    assert len(conn.supported_embedding_models) == 0
+
+
+@patch("kiln_ai.adapters.ollama_tools.built_in_models", MOCK_BUILT_IN_MODELS)
+@patch(
+    "kiln_ai.adapters.ollama_tools.built_in_embedding_models",
+    MOCK_BUILT_IN_EMBEDDING_MODELS,
+)
+def test_parse_ollama_tags_empty_models():
+    """Test parsing Ollama tags response with empty models list"""
+    json_response = '{"models": []}'
+    tags = json.loads(json_response)
+    conn = parse_ollama_tags(tags)
+
+    # Check that connection indicates no supported models
+    assert conn.supported_models == []
+    assert conn.untested_models == []
+    assert conn.supported_embedding_models == []
+    assert "no supported models are installed" in conn.message
+
+
+@patch("kiln_ai.adapters.ollama_tools.built_in_models", MOCK_BUILT_IN_MODELS)
+@patch(
+    "kiln_ai.adapters.ollama_tools.built_in_embedding_models",
+    MOCK_BUILT_IN_EMBEDDING_MODELS,
+)
 def test_parse_ollama_tags_only_untested_models():
     json_response = '{"models":[{"name":"scosman_net","model":"scosman_net:latest"}]}'
     tags = json.loads(json_response)
@@ -24,12 +141,17 @@ def test_parse_ollama_tags_only_untested_models():
     assert conn.supported_models == []
     assert conn.untested_models == ["scosman_net:latest"]
 
+    # there should be no embedding models because the tags response does not include any embedding models
+    # that are in the built-in embedding models list
+    assert len(conn.supported_embedding_models) == 0
+
 
 def test_ollama_model_installed():
     conn = OllamaConnection(
         supported_models=["phi3.5:latest", "gemma2:2b", "llama3.1:latest"],
         message="Connected",
         untested_models=["scosman_net:latest"],
+        supported_embedding_models=["embeddinggemma:300m"],
     )
     assert ollama_model_installed(conn, "phi3.5:latest")
     assert ollama_model_installed(conn, "phi3.5")
@@ -39,3 +161,220 @@ def test_ollama_model_installed():
     assert ollama_model_installed(conn, "scosman_net:latest")
     assert ollama_model_installed(conn, "scosman_net")
     assert not ollama_model_installed(conn, "unknown_model")
+
+    # use the ollama_embedding_model_installed for testing embedding models installed, not ollama_model_installed
+    assert not ollama_model_installed(conn, "embeddinggemma:300m")
+    assert not ollama_model_installed(conn, "embeddinggemma")
+
+
+def test_ollama_model_installed_embedding_models():
+    conn = OllamaConnection(
+        message="Connected",
+        supported_models=["phi3.5:latest", "gemma2:2b", "llama3.1:latest"],
+        untested_models=["scosman_net:latest"],
+        supported_embedding_models=["embeddinggemma:300m", "embeddinggemma:latest"],
+    )
+
+    assert ollama_embedding_model_installed(conn, "embeddinggemma:300m")
+    assert ollama_embedding_model_installed(conn, "embeddinggemma:latest")
+    assert not ollama_embedding_model_installed(conn, "unknown_embedding")
+
+    # use the ollama_model_installed for testing regular models installed, not ollama_embedding_model_installed
+    assert not ollama_embedding_model_installed(conn, "phi3.5:latest")
+    assert not ollama_embedding_model_installed(conn, "gemma2:2b")
+    assert not ollama_embedding_model_installed(conn, "llama3.1:latest")
+    assert not ollama_embedding_model_installed(conn, "scosman_net:latest")
+
+
+@patch("kiln_ai.adapters.ollama_tools.built_in_models", MOCK_BUILT_IN_MODELS)
+@patch(
+    "kiln_ai.adapters.ollama_tools.built_in_embedding_models",
+    MOCK_BUILT_IN_EMBEDDING_MODELS,
+)
+def test_parse_ollama_tags_with_embedding_models():
+    """Test parsing Ollama tags response that includes embedding models"""
+    json_response = """{
+        "models": [
+            {
+                "name": "phi3.5:latest",
+                "model": "phi3.5:latest"
+            },
+            {
+                "name": "embeddinggemma:300m",
+                "model": "embeddinggemma:300m"
+            },
+            {
+                "name": "embeddinggemma:latest",
+                "model": "embeddinggemma:latest"
+            },
+            {
+                "name": "unknown_embedding:latest",
+                "model": "unknown_embedding:latest"
+            }
+        ]
+    }"""
+    tags = json.loads(json_response)
+    conn = parse_ollama_tags(tags)
+
+    # Check that embedding models are properly categorized
+    assert "embeddinggemma:300m" in conn.supported_embedding_models
+    assert "embeddinggemma:latest" in conn.supported_embedding_models
+
+    # Check that regular models are still parsed correctly
+    assert "phi3.5:latest" in conn.supported_models
+
+    # Check that embedding models are NOT in the main model lists
+    assert "embeddinggemma:300m" not in conn.supported_models
+    assert "embeddinggemma:latest" not in conn.supported_models
+    assert "embeddinggemma:300m" not in conn.untested_models
+    assert "embeddinggemma:latest" not in conn.untested_models
+
+    # we assume the unknown models are normal models, not embedding models - because
+    # we don't support untested embedding models currently
+    assert "unknown_embedding:latest" not in conn.supported_embedding_models
+    assert "unknown_embedding:latest" in conn.untested_models
+
+
+@patch("kiln_ai.adapters.ollama_tools.built_in_models", MOCK_BUILT_IN_MODELS)
+@patch(
+    "kiln_ai.adapters.ollama_tools.built_in_embedding_models",
+    MOCK_BUILT_IN_EMBEDDING_MODELS,
+)
+def test_parse_ollama_tags_embedding_model_aliases():
+    """Test parsing Ollama tags response with embedding model aliases"""
+    json_response = """{
+        "models": [
+            {
+                "name": "embeddinggemma",
+                "model": "embeddinggemma"
+            }
+        ]
+    }"""
+    tags = json.loads(json_response)
+    conn = parse_ollama_tags(tags)
+
+    # Check that embedding model aliases are recognized
+    assert "embeddinggemma" in conn.supported_embedding_models
+
+    # Check that embedding model aliases are NOT in the main model lists
+    assert "embeddinggemma" not in conn.supported_models
+    assert "embeddinggemma" not in conn.untested_models
+
+    assert len(conn.supported_models) == 0
+    assert len(conn.untested_models) == 0
+    assert len(conn.supported_embedding_models) == 1
+
+
+@patch("kiln_ai.adapters.ollama_tools.built_in_models", MOCK_BUILT_IN_MODELS)
+@patch(
+    "kiln_ai.adapters.ollama_tools.built_in_embedding_models",
+    MOCK_BUILT_IN_EMBEDDING_MODELS,
+)
+def test_parse_ollama_tags_only_embedding_models():
+    """Test parsing Ollama tags response with only embedding models"""
+    json_response = """{
+        "models": [
+            {
+                "name": "embeddinggemma:300m",
+                "model": "embeddinggemma:300m"
+            }
+        ]
+    }"""
+    tags = json.loads(json_response)
+    conn = parse_ollama_tags(tags)
+
+    # Check that embedding models are found but no regular models
+    assert "embeddinggemma:300m" in conn.supported_embedding_models
+    assert conn.supported_models == []
+    assert conn.untested_models == []
+
+    # Check that embedding models are NOT in the main model lists
+    assert "embeddinggemma:300m" not in conn.supported_models
+    assert "embeddinggemma:300m" not in conn.untested_models
+
+
+def test_ollama_connection_all_embedding_models():
+    """Test OllamaConnection.all_embedding_models() method"""
+    conn = OllamaConnection(
+        message="Connected",
+        supported_models=["phi3.5:latest"],
+        untested_models=["unknown:latest"],
+        supported_embedding_models=["embeddinggemma:300m", "embeddinggemma:latest"],
+    )
+
+    embedding_models = conn.all_embedding_models()
+    assert embedding_models == ["embeddinggemma:300m", "embeddinggemma:latest"]
+
+
+def test_ollama_connection_empty_embedding_models():
+    """Test OllamaConnection.all_embedding_models() with empty list"""
+    conn = OllamaConnection(
+        message="Connected",
+        supported_models=["phi3.5:latest"],
+        untested_models=["unknown:latest"],
+        supported_embedding_models=[],
+    )
+
+    embedding_models = conn.all_embedding_models()
+    assert embedding_models == []
+
+
+@patch("kiln_ai.adapters.ollama_tools.built_in_models", MOCK_BUILT_IN_MODELS)
+@patch(
+    "kiln_ai.adapters.ollama_tools.built_in_embedding_models",
+    MOCK_BUILT_IN_EMBEDDING_MODELS,
+)
+def test_parse_ollama_tags_mixed_models_and_embeddings():
+    """Test parsing response with mix of regular models, embedding models, and unknown models"""
+    json_response = """{
+        "models": [
+            {
+                "name": "phi3.5:latest",
+                "model": "phi3.5:latest"
+            },
+            {
+                "name": "gemma2:2b",
+                "model": "gemma2:2b"
+            },
+            {
+                "name": "embeddinggemma:300m",
+                "model": "embeddinggemma:300m"
+            },
+            {
+                "name": "embeddinggemma",
+                "model": "embeddinggemma"
+            },
+            {
+                "name": "unknown_model:latest",
+                "model": "unknown_model:latest"
+            },
+            {
+                "name": "unknown_embedding:latest",
+                "model": "unknown_embedding:latest"
+            }
+        ]
+    }"""
+    tags = json.loads(json_response)
+    conn = parse_ollama_tags(tags)
+
+    # Check regular models
+    assert "phi3.5:latest" in conn.supported_models
+    assert "gemma2:2b" in conn.supported_models
+    assert "unknown_model:latest" in conn.untested_models
+
+    # Check embedding models
+    assert "embeddinggemma:300m" in conn.supported_embedding_models
+    assert "embeddinggemma" in conn.supported_embedding_models
+
+    # Check that embedding models are NOT in the main model lists
+    assert "embeddinggemma:300m" not in conn.supported_models
+    assert "embeddinggemma" not in conn.supported_models
+    assert "embeddinggemma:300m" not in conn.untested_models
+    assert "embeddinggemma" not in conn.untested_models
+
+    # Unknown embedding models should not appear in supported_embedding_models
+    assert "unknown_embedding:latest" not in conn.supported_embedding_models
+
+    # Unknown embedding models should appear in untested_models (since they're not recognized as embeddings)
+    assert "unknown_embedding:latest" not in conn.supported_models
+    assert "unknown_embedding:latest" in conn.untested_models
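The tests above pin down a three-way split: models reported by Ollama's /api/tags endpoint are bucketed into supported models, supported embedding models, and untested models, with embedding matching done against the built-in embedding model list (including aliases) before anything falls through to untested_models. Below is a minimal, illustrative sketch of that bucketing, written against hypothetical plain-dict inputs and hand-rolled sets rather than the package's KilnModel/KilnEmbeddingModel types; it is not the real parse_ollama_tags implementation.

# Illustrative sketch of the bucketing the tests assert on. SUPPORTED_IDS,
# EMBEDDING_IDS, and EMBEDDING_ALIASES are assumptions standing in for the
# built_in_models / built_in_embedding_models lists in kiln_ai.

SUPPORTED_IDS = {"phi3.5", "gemma2:2b", "llama3.1"}  # assumed built-in LLM ids
EMBEDDING_IDS = {"embeddinggemma:300m"}  # assumed built-in embedding ids
EMBEDDING_ALIASES = {"embeddinggemma"}  # assumed ollama_model_aliases


def bucket_tags(tags: dict) -> tuple[list[str], list[str], list[str]]:
    """Split an Ollama /api/tags payload into (supported, embedding, untested)."""
    supported, embedding, untested = [], [], []
    for model in tags.get("models", []):
        name = model["model"]
        base = name.removesuffix(":latest")
        if name in EMBEDDING_IDS or base in EMBEDDING_IDS | EMBEDDING_ALIASES:
            embedding.append(name)  # known embedding model (or alias)
        elif name in SUPPORTED_IDS or base in SUPPORTED_IDS:
            supported.append(name)  # known chat/completion model
        else:
            # unknown tags are treated as untested LLMs, never as embedding
            # models: the tests note untested embedding models are unsupported
            untested.append(name)
    return supported, embedding, untested


# e.g. "embeddinggemma:latest" resolves via its alias into the embedding bucket
assert bucket_tags({"models": [{"model": "embeddinggemma:latest"}]})[1] == [
    "embeddinggemma:latest"
]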
kiln_ai/adapters/test_prompt_builders.py
@@ -359,7 +359,7 @@ def test_prompt_builder_from_id(task_with_examples):
 
     with pytest.raises(
         ValueError,
-        match="Invalid fine-tune ID format. Expected 'project_id::task_id::fine_tune_id'",
+        match=r"Invalid fine-tune ID format. Expected 'project_id::task_id::fine_tune_id'",
     ):
         prompt_builder_from_id("fine_tune_prompt::123", task)
 
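The only change in this hunk is the r"..." prefix: pytest.raises(match=...) interprets the pattern with re.search, so marking it as a raw string makes the regex intent explicit without changing what the test matches.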
kiln_ai/adapters/test_provider_tools.py
@@ -2,6 +2,7 @@ from unittest.mock import AsyncMock, Mock, patch
 
 import pytest
 
+from kiln_ai.adapters.adapter_registry import litellm_core_provider_config
 from kiln_ai.adapters.docker_model_runner_tools import DockerModelRunnerConnection
 from kiln_ai.adapters.ml_model_list import (
     KilnModel,
@@ -11,6 +12,7 @@ from kiln_ai.adapters.ml_model_list import (
 )
 from kiln_ai.adapters.ollama_tools import OllamaConnection
 from kiln_ai.adapters.provider_tools import (
+    LiteLlmCoreConfig,
     builtin_model_from,
     check_provider_warnings,
     core_provider,
@@ -19,7 +21,7 @@ from kiln_ai.adapters.provider_tools import (
     finetune_provider_model,
     get_model_and_provider,
     kiln_model_provider_from,
-
+    lite_llm_core_config_for_provider,
     lite_llm_provider_model,
     parse_custom_model_id,
     provider_enabled,
@@ -604,7 +606,7 @@ def test_openai_compatible_provider_config(mock_shared_config):
     """Test successful creation of an OpenAI compatible provider"""
     model_id = "test_provider::gpt-4"
 
-    config =
+    config = litellm_core_provider_config(
         RunConfigProperties(
             model_name=model_id,
             model_provider_name=ModelProviderName.openai_compatible,
@@ -639,10 +641,10 @@ def test_lite_llm_config_no_api_key(mock_shared_config):
     """Test provider creation without API key (should work as some providers don't require it, but should pass NA to LiteLLM as it requires one)"""
     model_id = "no_key_provider::gpt-4"
 
-    config =
+    config = litellm_core_provider_config(
         RunConfigProperties(
             model_name=model_id,
-            model_provider_name=ModelProviderName.
+            model_provider_name=ModelProviderName.openai_compatible,
             prompt_id="simple_prompt_builder",
             structured_output_mode="json_schema",
         )
@@ -660,7 +662,7 @@ def test_lite_llm_config_no_api_key(mock_shared_config):
 def test_lite_llm_config_invalid_id():
     """Test handling of invalid model ID format"""
     with pytest.raises(ValueError) as exc_info:
-
+        litellm_core_provider_config(
             RunConfigProperties(
                 model_name="invalid-id-format",
                 model_provider_name=ModelProviderName.openai_compatible,
@@ -678,7 +680,7 @@ def test_lite_llm_config_no_providers(mock_shared_config):
     mock_shared_config.return_value.openai_compatible_providers = None
 
     with pytest.raises(ValueError) as exc_info:
-
+        litellm_core_provider_config(
             RunConfigProperties(
                 model_name="test_provider::gpt-4",
                 model_provider_name=ModelProviderName.openai_compatible,
@@ -692,7 +694,7 @@ def test_lite_llm_config_no_providers(mock_shared_config):
 def test_lite_llm_config_provider_not_found(mock_shared_config):
     """Test handling of non-existent provider"""
     with pytest.raises(ValueError) as exc_info:
-
+        litellm_core_provider_config(
             RunConfigProperties(
                 model_name="unknown_provider::gpt-4",
                 model_provider_name=ModelProviderName.openai_compatible,
@@ -715,7 +717,7 @@ def test_lite_llm_config_no_base_url(mock_shared_config):
     ]
 
     with pytest.raises(ValueError) as exc_info:
-
+        litellm_core_provider_config(
             RunConfigProperties(
                 model_name="test_provider::gpt-4",
                 model_provider_name=ModelProviderName.openai_compatible,
@@ -934,6 +936,195 @@ def test_finetune_provider_model_vertex_ai(mock_project, mock_task, mock_finetun
     assert provider.structured_output_mode == StructuredOutputMode.json_mode
 
 
+@pytest.fixture
+def mock_config_for_lite_llm_core_config():
+    with patch("kiln_ai.adapters.provider_tools.Config") as mock:
+        config_instance = Mock()
+        mock.shared.return_value = config_instance
+
+        # Set up all the config values
+        config_instance.open_router_api_key = "test-openrouter-key"
+        config_instance.open_ai_api_key = "test-openai-key"
+        config_instance.groq_api_key = "test-groq-key"
+        config_instance.bedrock_access_key = "test-aws-access-key"
+        config_instance.bedrock_secret_key = "test-aws-secret-key"
+        config_instance.ollama_base_url = "http://test-ollama:11434"
+        config_instance.fireworks_api_key = "test-fireworks-key"
+        config_instance.anthropic_api_key = "test-anthropic-key"
+        config_instance.gemini_api_key = "test-gemini-key"
+        config_instance.vertex_project_id = "test-vertex-project"
+        config_instance.vertex_location = "us-central1"
+        config_instance.together_api_key = "test-together-key"
+        config_instance.azure_openai_api_key = "test-azure-key"
+        config_instance.azure_openai_endpoint = "https://test.openai.azure.com"
+        config_instance.huggingface_api_key = "test-hf-key"
+
+        yield mock
+
+
+@pytest.mark.parametrize(
+    "provider_name,expected_config",
+    [
+        (
+            ModelProviderName.openrouter,
+            LiteLlmCoreConfig(
+                base_url="https://openrouter.ai/api/v1",
+                additional_body_options={
+                    "api_key": "test-openrouter-key",
+                },
+                default_headers={
+                    "HTTP-Referer": "https://kiln.tech/openrouter",
+                    "X-Title": "KilnAI",
+                },
+            ),
+        ),
+        (
+            ModelProviderName.openai,
+            LiteLlmCoreConfig(additional_body_options={"api_key": "test-openai-key"}),
+        ),
+        (
+            ModelProviderName.groq,
+            LiteLlmCoreConfig(additional_body_options={"api_key": "test-groq-key"}),
+        ),
+        (
+            ModelProviderName.amazon_bedrock,
+            LiteLlmCoreConfig(
+                additional_body_options={
+                    "aws_access_key_id": "test-aws-access-key",
+                    "aws_secret_access_key": "test-aws-secret-key",
+                    "aws_region_name": "us-west-2",
+                },
+            ),
+        ),
+        (
+            ModelProviderName.ollama,
+            LiteLlmCoreConfig(
+                base_url="http://test-ollama:11434/v1",
+                additional_body_options={"api_key": "NA"},
+            ),
+        ),
+        (
+            ModelProviderName.fireworks_ai,
+            LiteLlmCoreConfig(
+                additional_body_options={"api_key": "test-fireworks-key"}
+            ),
+        ),
+        (
+            ModelProviderName.anthropic,
+            LiteLlmCoreConfig(
+                additional_body_options={"api_key": "test-anthropic-key"}
+            ),
+        ),
+        (
+            ModelProviderName.gemini_api,
+            LiteLlmCoreConfig(additional_body_options={"api_key": "test-gemini-key"}),
+        ),
+        (
+            ModelProviderName.vertex,
+            LiteLlmCoreConfig(
+                additional_body_options={
+                    "vertex_project": "test-vertex-project",
+                    "vertex_location": "us-central1",
+                },
+            ),
+        ),
+        (
+            ModelProviderName.together_ai,
+            LiteLlmCoreConfig(additional_body_options={"api_key": "test-together-key"}),
+        ),
+        (
+            ModelProviderName.azure_openai,
+            LiteLlmCoreConfig(
+                base_url="https://test.openai.azure.com",
+                additional_body_options={
+                    "api_key": "test-azure-key",
+                    "api_version": "2025-02-01-preview",
+                },
+            ),
+        ),
+        (
+            ModelProviderName.huggingface,
+            LiteLlmCoreConfig(additional_body_options={"api_key": "test-hf-key"}),
+        ),
+        (ModelProviderName.kiln_fine_tune, None),
+        (ModelProviderName.kiln_custom_registry, None),
+    ],
+)
+def test_lite_llm_core_config_for_provider(
+    mock_config_for_lite_llm_core_config, provider_name, expected_config
+):
+    config = lite_llm_core_config_for_provider(provider_name)
+    assert config == expected_config
+
+
+def test_lite_llm_core_config_for_provider_openai_compatible(
+    mock_shared_config,
+):
+    config = lite_llm_core_config_for_provider(
+        ModelProviderName.openai_compatible, "no_key_provider"
+    )
+    assert config is not None
+    assert config.base_url == "https://api.nokey.com"
+    assert config.additional_body_options == {"api_key": "NA"}
+
+
+def test_lite_llm_core_config_for_provider_openai_compatible_with_openai_compatible_provider_name(
+    mock_shared_config,
+):
+    with pytest.raises(
+        ValueError, match="OpenAI compatible provider requires a provider name"
+    ):
+        lite_llm_core_config_for_provider(ModelProviderName.openai_compatible)
+
+
+def test_lite_llm_core_config_incorrect_openai_compatible_provider_name(
+    mock_shared_config,
+):
+    with pytest.raises(
+        ValueError,
+        match="OpenAI compatible provider provider_that_does_not_exist_in_compatible_openai_providers not found",
+    ):
+        lite_llm_core_config_for_provider(
+            ModelProviderName.openai_compatible,
+            "provider_that_does_not_exist_in_compatible_openai_providers",
+        )
+
+
+def test_lite_llm_core_config_for_provider_with_string(
+    mock_config_for_lite_llm_core_config,
+):
+    # test with a string instead of an enum
+    config = lite_llm_core_config_for_provider("openai")
+    assert config == LiteLlmCoreConfig(
+        additional_body_options={"api_key": "test-openai-key"}
+    )
+
+
+def test_lite_llm_core_config_for_provider_unknown_provider():
+    with pytest.raises(ValueError, match="Unhandled enum value: unknown_provider"):
+        lite_llm_core_config_for_provider("unknown_provider")
+
+
+@patch.dict("os.environ", {"OPENROUTER_BASE_URL": "https://custom-openrouter.com"})
+def test_lite_llm_core_config_for_provider_openrouter_custom_url(
+    mock_config_for_lite_llm_core_config,
+):
+    config = lite_llm_core_config_for_provider(ModelProviderName.openrouter)
+    assert config is not None
+    assert config.base_url == "https://custom-openrouter.com"
+
+
+def test_lite_llm_core_config_for_provider_ollama_default_url(
+    mock_config_for_lite_llm_core_config,
+):
+    # Override the mock to return None for ollama_base_url
+    mock_config_for_lite_llm_core_config.shared.return_value.ollama_base_url = None
+
+    config = lite_llm_core_config_for_provider(ModelProviderName.ollama)
+    assert config is not None
+    assert config.base_url == "http://localhost:11434/v1"
+
+
 @pytest.mark.asyncio
 async def test_provider_enabled_docker_model_runner_success():
     """Test provider_enabled for Docker Model Runner with successful connection"""
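Taken together, the parametrized cases pin down the contract of the new lite_llm_core_config_for_provider helper: each provider maps to a LiteLlmCoreConfig carrying the credentials LiteLLM needs (base URL, body options, optional headers), virtual providers (kiln_fine_tune, kiln_custom_registry) map to None, string inputs are accepted alongside the enum, and unknown values raise. The sketch below condenses that shape; CoreConfig and core_config_for_provider are hypothetical stand-ins for the real names in kiln_ai.adapters.provider_tools, which read credentials from Config.shared() rather than a keys dict.

# Condensed, illustrative sketch of the provider-to-config mapping the
# parametrized test asserts. Names and the `keys` parameter are assumptions.
import os
from dataclasses import dataclass, field


@dataclass
class CoreConfig:  # hypothetical stand-in for LiteLlmCoreConfig
    base_url: str | None = None
    additional_body_options: dict[str, str] = field(default_factory=dict)
    default_headers: dict[str, str] | None = None


def core_config_for_provider(name: str, keys: dict[str, str]) -> CoreConfig | None:
    match name:
        case "openrouter":
            return CoreConfig(
                # OPENROUTER_BASE_URL overrides the default, as the test asserts
                base_url=os.environ.get(
                    "OPENROUTER_BASE_URL", "https://openrouter.ai/api/v1"
                ),
                additional_body_options={"api_key": keys["openrouter"]},
                default_headers={
                    "HTTP-Referer": "https://kiln.tech/openrouter",
                    "X-Title": "KilnAI",
                },
            )
        case "ollama":
            # LiteLLM requires an api_key, so "NA" is passed when none is
            # needed; the base URL falls back to localhost when unconfigured
            base = keys.get("ollama_base_url") or "http://localhost:11434"
            return CoreConfig(
                base_url=f"{base}/v1", additional_body_options={"api_key": "NA"}
            )
        case "openai" | "groq" | "anthropic":
            return CoreConfig(additional_body_options={"api_key": keys[name]})
        case "kiln_fine_tune" | "kiln_custom_registry":
            return None  # virtual providers carry no direct LiteLLM config
        case _:
            raise ValueError(f"Unhandled enum value: {name}")


# e.g. core_config_for_provider("openai", {"openai": "sk-test"}) yields
# CoreConfig(additional_body_options={"api_key": "sk-test"})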