kiln-ai 0.20.1__py3-none-any.whl → 0.22.0__py3-none-any.whl

This diff shows the changes between publicly available package versions as they appear in their respective public registries. It is provided for informational purposes only.

Potentially problematic release: this version of kiln-ai might be problematic.

Files changed (133)
  1. kiln_ai/adapters/__init__.py +6 -0
  2. kiln_ai/adapters/adapter_registry.py +43 -226
  3. kiln_ai/adapters/chunkers/__init__.py +13 -0
  4. kiln_ai/adapters/chunkers/base_chunker.py +42 -0
  5. kiln_ai/adapters/chunkers/chunker_registry.py +16 -0
  6. kiln_ai/adapters/chunkers/fixed_window_chunker.py +39 -0
  7. kiln_ai/adapters/chunkers/helpers.py +23 -0
  8. kiln_ai/adapters/chunkers/test_base_chunker.py +63 -0
  9. kiln_ai/adapters/chunkers/test_chunker_registry.py +28 -0
  10. kiln_ai/adapters/chunkers/test_fixed_window_chunker.py +346 -0
  11. kiln_ai/adapters/chunkers/test_helpers.py +75 -0
  12. kiln_ai/adapters/data_gen/test_data_gen_task.py +9 -3
  13. kiln_ai/adapters/embedding/__init__.py +0 -0
  14. kiln_ai/adapters/embedding/base_embedding_adapter.py +44 -0
  15. kiln_ai/adapters/embedding/embedding_registry.py +32 -0
  16. kiln_ai/adapters/embedding/litellm_embedding_adapter.py +199 -0
  17. kiln_ai/adapters/embedding/test_base_embedding_adapter.py +283 -0
  18. kiln_ai/adapters/embedding/test_embedding_registry.py +166 -0
  19. kiln_ai/adapters/embedding/test_litellm_embedding_adapter.py +1149 -0
  20. kiln_ai/adapters/eval/eval_runner.py +6 -2
  21. kiln_ai/adapters/eval/test_base_eval.py +1 -3
  22. kiln_ai/adapters/eval/test_g_eval.py +1 -1
  23. kiln_ai/adapters/extractors/__init__.py +18 -0
  24. kiln_ai/adapters/extractors/base_extractor.py +72 -0
  25. kiln_ai/adapters/extractors/encoding.py +20 -0
  26. kiln_ai/adapters/extractors/extractor_registry.py +44 -0
  27. kiln_ai/adapters/extractors/extractor_runner.py +112 -0
  28. kiln_ai/adapters/extractors/litellm_extractor.py +406 -0
  29. kiln_ai/adapters/extractors/test_base_extractor.py +244 -0
  30. kiln_ai/adapters/extractors/test_encoding.py +54 -0
  31. kiln_ai/adapters/extractors/test_extractor_registry.py +181 -0
  32. kiln_ai/adapters/extractors/test_extractor_runner.py +181 -0
  33. kiln_ai/adapters/extractors/test_litellm_extractor.py +1290 -0
  34. kiln_ai/adapters/fine_tune/test_dataset_formatter.py +2 -2
  35. kiln_ai/adapters/fine_tune/test_fireworks_finetune.py +2 -6
  36. kiln_ai/adapters/fine_tune/test_together_finetune.py +2 -6
  37. kiln_ai/adapters/ml_embedding_model_list.py +494 -0
  38. kiln_ai/adapters/ml_model_list.py +876 -18
  39. kiln_ai/adapters/model_adapters/litellm_adapter.py +40 -75
  40. kiln_ai/adapters/model_adapters/test_litellm_adapter.py +79 -1
  41. kiln_ai/adapters/model_adapters/test_litellm_adapter_tools.py +119 -5
  42. kiln_ai/adapters/model_adapters/test_saving_adapter_results.py +9 -3
  43. kiln_ai/adapters/model_adapters/test_structured_output.py +9 -10
  44. kiln_ai/adapters/ollama_tools.py +69 -12
  45. kiln_ai/adapters/provider_tools.py +190 -46
  46. kiln_ai/adapters/rag/deduplication.py +49 -0
  47. kiln_ai/adapters/rag/progress.py +252 -0
  48. kiln_ai/adapters/rag/rag_runners.py +844 -0
  49. kiln_ai/adapters/rag/test_deduplication.py +195 -0
  50. kiln_ai/adapters/rag/test_progress.py +785 -0
  51. kiln_ai/adapters/rag/test_rag_runners.py +2376 -0
  52. kiln_ai/adapters/remote_config.py +80 -8
  53. kiln_ai/adapters/test_adapter_registry.py +579 -86
  54. kiln_ai/adapters/test_ml_embedding_model_list.py +239 -0
  55. kiln_ai/adapters/test_ml_model_list.py +202 -0
  56. kiln_ai/adapters/test_ollama_tools.py +340 -1
  57. kiln_ai/adapters/test_prompt_builders.py +1 -1
  58. kiln_ai/adapters/test_provider_tools.py +199 -8
  59. kiln_ai/adapters/test_remote_config.py +551 -56
  60. kiln_ai/adapters/vector_store/__init__.py +1 -0
  61. kiln_ai/adapters/vector_store/base_vector_store_adapter.py +83 -0
  62. kiln_ai/adapters/vector_store/lancedb_adapter.py +389 -0
  63. kiln_ai/adapters/vector_store/test_base_vector_store.py +160 -0
  64. kiln_ai/adapters/vector_store/test_lancedb_adapter.py +1841 -0
  65. kiln_ai/adapters/vector_store/test_vector_store_registry.py +199 -0
  66. kiln_ai/adapters/vector_store/vector_store_registry.py +33 -0
  67. kiln_ai/datamodel/__init__.py +16 -13
  68. kiln_ai/datamodel/basemodel.py +201 -4
  69. kiln_ai/datamodel/chunk.py +158 -0
  70. kiln_ai/datamodel/datamodel_enums.py +27 -0
  71. kiln_ai/datamodel/embedding.py +64 -0
  72. kiln_ai/datamodel/external_tool_server.py +206 -54
  73. kiln_ai/datamodel/extraction.py +317 -0
  74. kiln_ai/datamodel/project.py +33 -1
  75. kiln_ai/datamodel/rag.py +79 -0
  76. kiln_ai/datamodel/task.py +5 -0
  77. kiln_ai/datamodel/task_output.py +41 -11
  78. kiln_ai/datamodel/test_attachment.py +649 -0
  79. kiln_ai/datamodel/test_basemodel.py +270 -14
  80. kiln_ai/datamodel/test_chunk_models.py +317 -0
  81. kiln_ai/datamodel/test_dataset_split.py +1 -1
  82. kiln_ai/datamodel/test_datasource.py +50 -0
  83. kiln_ai/datamodel/test_embedding_models.py +448 -0
  84. kiln_ai/datamodel/test_eval_model.py +6 -6
  85. kiln_ai/datamodel/test_external_tool_server.py +534 -152
  86. kiln_ai/datamodel/test_extraction_chunk.py +206 -0
  87. kiln_ai/datamodel/test_extraction_model.py +501 -0
  88. kiln_ai/datamodel/test_rag.py +641 -0
  89. kiln_ai/datamodel/test_task.py +35 -1
  90. kiln_ai/datamodel/test_tool_id.py +187 -1
  91. kiln_ai/datamodel/test_vector_store.py +320 -0
  92. kiln_ai/datamodel/tool_id.py +58 -0
  93. kiln_ai/datamodel/vector_store.py +141 -0
  94. kiln_ai/tools/base_tool.py +12 -3
  95. kiln_ai/tools/built_in_tools/math_tools.py +12 -4
  96. kiln_ai/tools/kiln_task_tool.py +158 -0
  97. kiln_ai/tools/mcp_server_tool.py +2 -2
  98. kiln_ai/tools/mcp_session_manager.py +51 -22
  99. kiln_ai/tools/rag_tools.py +164 -0
  100. kiln_ai/tools/test_kiln_task_tool.py +527 -0
  101. kiln_ai/tools/test_mcp_server_tool.py +4 -15
  102. kiln_ai/tools/test_mcp_session_manager.py +187 -227
  103. kiln_ai/tools/test_rag_tools.py +929 -0
  104. kiln_ai/tools/test_tool_registry.py +290 -7
  105. kiln_ai/tools/tool_registry.py +69 -16
  106. kiln_ai/utils/__init__.py +3 -0
  107. kiln_ai/utils/async_job_runner.py +62 -17
  108. kiln_ai/utils/config.py +2 -2
  109. kiln_ai/utils/env.py +15 -0
  110. kiln_ai/utils/filesystem.py +14 -0
  111. kiln_ai/utils/filesystem_cache.py +60 -0
  112. kiln_ai/utils/litellm.py +94 -0
  113. kiln_ai/utils/lock.py +100 -0
  114. kiln_ai/utils/mime_type.py +38 -0
  115. kiln_ai/utils/open_ai_types.py +19 -2
  116. kiln_ai/utils/pdf_utils.py +59 -0
  117. kiln_ai/utils/test_async_job_runner.py +151 -35
  118. kiln_ai/utils/test_env.py +142 -0
  119. kiln_ai/utils/test_filesystem_cache.py +316 -0
  120. kiln_ai/utils/test_litellm.py +206 -0
  121. kiln_ai/utils/test_lock.py +185 -0
  122. kiln_ai/utils/test_mime_type.py +66 -0
  123. kiln_ai/utils/test_open_ai_types.py +88 -12
  124. kiln_ai/utils/test_pdf_utils.py +86 -0
  125. kiln_ai/utils/test_uuid.py +111 -0
  126. kiln_ai/utils/test_validation.py +524 -0
  127. kiln_ai/utils/uuid.py +9 -0
  128. kiln_ai/utils/validation.py +90 -0
  129. {kiln_ai-0.20.1.dist-info → kiln_ai-0.22.0.dist-info}/METADATA +9 -1
  130. kiln_ai-0.22.0.dist-info/RECORD +213 -0
  131. kiln_ai-0.20.1.dist-info/RECORD +0 -138
  132. {kiln_ai-0.20.1.dist-info → kiln_ai-0.22.0.dist-info}/WHEEL +0 -0
  133. {kiln_ai-0.20.1.dist-info → kiln_ai-0.22.0.dist-info}/licenses/LICENSE.txt +0 -0
kiln_ai/adapters/test_ollama_tools.py

@@ -1,13 +1,85 @@
 import json
+from unittest.mock import patch
 
+import pytest
+
+from kiln_ai.adapters.ml_embedding_model_list import (
+    KilnEmbeddingModel,
+    KilnEmbeddingModelProvider,
+)
+
+# Mock data for testing - using proper Pydantic model instances
+from kiln_ai.adapters.ml_model_list import KilnModel, KilnModelProvider
 from kiln_ai.adapters.ollama_tools import (
     OllamaConnection,
+    ollama_embedding_model_installed,
     ollama_model_installed,
     parse_ollama_tags,
 )
+from kiln_ai.datamodel.datamodel_enums import ModelProviderName
 
+MOCK_BUILT_IN_MODELS = [
+    KilnModel(
+        family="phi",
+        name="phi3.5",
+        friendly_name="phi3.5",
+        providers=[
+            KilnModelProvider(
+                name=ModelProviderName.ollama,
+                model_id="phi3.5",
+                ollama_model_aliases=None,
+            )
+        ],
+    ),
+    KilnModel(
+        family="gemma",
+        name="gemma2",
+        friendly_name="gemma2",
+        providers=[
+            KilnModelProvider(
+                name=ModelProviderName.ollama,
+                model_id="gemma2:2b",
+                ollama_model_aliases=None,
+            )
+        ],
+    ),
+    KilnModel(
+        family="llama",
+        name="llama3.1",
+        friendly_name="llama3.1",
+        providers=[
+            KilnModelProvider(
+                name=ModelProviderName.ollama,
+                model_id="llama3.1",
+                ollama_model_aliases=None,
+            )
+        ],
+    ),
+]
 
-def test_parse_ollama_tags_no_models():
+MOCK_BUILT_IN_EMBEDDING_MODELS = [
+    KilnEmbeddingModel(
+        family="gemma",
+        name="embeddinggemma",
+        friendly_name="embeddinggemma",
+        providers=[
+            KilnEmbeddingModelProvider(
+                name=ModelProviderName.ollama,
+                model_id="embeddinggemma:300m",
+                n_dimensions=768,
+                ollama_model_aliases=["embeddinggemma"],
+            )
+        ],
+    ),
+]
+
+
+@patch("kiln_ai.adapters.ollama_tools.built_in_models", MOCK_BUILT_IN_MODELS)
+@patch(
+    "kiln_ai.adapters.ollama_tools.built_in_embedding_models",
+    MOCK_BUILT_IN_EMBEDDING_MODELS,
+)
+def test_parse_ollama_tags_models():
     json_response = '{"models":[{"name":"scosman_net","model":"scosman_net:latest"},{"name":"phi3.5:latest","model":"phi3.5:latest","modified_at":"2024-10-02T12:04:35.191519822-04:00","size":2176178843,"digest":"61819fb370a3c1a9be6694869331e5f85f867a079e9271d66cb223acb81d04ba","details":{"parent_model":"","format":"gguf","family":"phi3","families":["phi3"],"parameter_size":"3.8B","quantization_level":"Q4_0"}},{"name":"gemma2:2b","model":"gemma2:2b","modified_at":"2024-09-09T16:46:38.64348929-04:00","size":1629518495,"digest":"8ccf136fdd5298f3ffe2d69862750ea7fb56555fa4d5b18c04e3fa4d82ee09d7","details":{"parent_model":"","format":"gguf","family":"gemma2","families":["gemma2"],"parameter_size":"2.6B","quantization_level":"Q4_0"}},{"name":"llama3.1:latest","model":"llama3.1:latest","modified_at":"2024-09-01T17:19:43.481523695-04:00","size":4661230720,"digest":"f66fc8dc39ea206e03ff6764fcc696b1b4dfb693f0b6ef751731dd4e6269046e","details":{"parent_model":"","format":"gguf","family":"llama","families":["llama"],"parameter_size":"8.0B","quantization_level":"Q4_0"}}]}'
     tags = json.loads(json_response)
     conn = parse_ollama_tags(tags)
@@ -16,7 +88,52 @@ def test_parse_ollama_tags_no_models():
     assert "llama3.1:latest" in conn.supported_models
     assert "scosman_net:latest" in conn.untested_models
 
+    # there should be no embedding models because the tags response does not include any embedding models
+    # that are in the built-in embedding models list
+    assert len(conn.supported_embedding_models) == 0
 
+
+@patch("kiln_ai.adapters.ollama_tools.built_in_models", MOCK_BUILT_IN_MODELS)
+@patch(
+    "kiln_ai.adapters.ollama_tools.built_in_embedding_models",
+    MOCK_BUILT_IN_EMBEDDING_MODELS,
+)
+@pytest.mark.parametrize("json_response", ["{}", '{"models": []}'])
+def test_parse_ollama_tags_no_models(json_response):
+    tags = json.loads(json_response)
+    conn = parse_ollama_tags(tags)
+    assert (
+        conn.message
+        == "Ollama is running, but no supported models are installed. Install one or more supported model, like 'ollama pull phi3.5'."
+    )
+    assert len(conn.supported_models) == 0
+    assert len(conn.untested_models) == 0
+    assert len(conn.supported_embedding_models) == 0
+
+
+@patch("kiln_ai.adapters.ollama_tools.built_in_models", MOCK_BUILT_IN_MODELS)
+@patch(
+    "kiln_ai.adapters.ollama_tools.built_in_embedding_models",
+    MOCK_BUILT_IN_EMBEDDING_MODELS,
+)
+def test_parse_ollama_tags_empty_models():
+    """Test parsing Ollama tags response with empty models list"""
+    json_response = '{"models": []}'
+    tags = json.loads(json_response)
+    conn = parse_ollama_tags(tags)
+
+    # Check that connection indicates no supported models
+    assert conn.supported_models == []
+    assert conn.untested_models == []
+    assert conn.supported_embedding_models == []
+    assert "no supported models are installed" in conn.message
+
+
+@patch("kiln_ai.adapters.ollama_tools.built_in_models", MOCK_BUILT_IN_MODELS)
+@patch(
+    "kiln_ai.adapters.ollama_tools.built_in_embedding_models",
+    MOCK_BUILT_IN_EMBEDDING_MODELS,
+)
 def test_parse_ollama_tags_only_untested_models():
     json_response = '{"models":[{"name":"scosman_net","model":"scosman_net:latest"}]}'
     tags = json.loads(json_response)
@@ -24,12 +141,17 @@ def test_parse_ollama_tags_only_untested_models():
     assert conn.supported_models == []
     assert conn.untested_models == ["scosman_net:latest"]
 
+    # there should be no embedding models because the tags response does not include any embedding models
+    # that are in the built-in embedding models list
+    assert len(conn.supported_embedding_models) == 0
+
 
 def test_ollama_model_installed():
     conn = OllamaConnection(
         supported_models=["phi3.5:latest", "gemma2:2b", "llama3.1:latest"],
         message="Connected",
         untested_models=["scosman_net:latest"],
+        supported_embedding_models=["embeddinggemma:300m"],
     )
     assert ollama_model_installed(conn, "phi3.5:latest")
     assert ollama_model_installed(conn, "phi3.5")
@@ -39,3 +161,220 @@ def test_ollama_model_installed():
     assert ollama_model_installed(conn, "scosman_net:latest")
     assert ollama_model_installed(conn, "scosman_net")
     assert not ollama_model_installed(conn, "unknown_model")
+
+    # use the ollama_embedding_model_installed for testing embedding models installed, not ollama_model_installed
+    assert not ollama_model_installed(conn, "embeddinggemma:300m")
+    assert not ollama_model_installed(conn, "embeddinggemma")
+
+
+def test_ollama_model_installed_embedding_models():
+    conn = OllamaConnection(
+        message="Connected",
+        supported_models=["phi3.5:latest", "gemma2:2b", "llama3.1:latest"],
+        untested_models=["scosman_net:latest"],
+        supported_embedding_models=["embeddinggemma:300m", "embeddinggemma:latest"],
+    )
+
+    assert ollama_embedding_model_installed(conn, "embeddinggemma:300m")
+    assert ollama_embedding_model_installed(conn, "embeddinggemma:latest")
+    assert not ollama_embedding_model_installed(conn, "unknown_embedding")
+
+    # use the ollama_model_installed for testing regular models installed, not ollama_embedding_model_installed
+    assert not ollama_embedding_model_installed(conn, "phi3.5:latest")
+    assert not ollama_embedding_model_installed(conn, "gemma2:2b")
+    assert not ollama_embedding_model_installed(conn, "llama3.1:latest")
+    assert not ollama_embedding_model_installed(conn, "scosman_net:latest")
+
+
+@patch("kiln_ai.adapters.ollama_tools.built_in_models", MOCK_BUILT_IN_MODELS)
+@patch(
+    "kiln_ai.adapters.ollama_tools.built_in_embedding_models",
+    MOCK_BUILT_IN_EMBEDDING_MODELS,
+)
+def test_parse_ollama_tags_with_embedding_models():
+    """Test parsing Ollama tags response that includes embedding models"""
+    json_response = """{
+        "models": [
+            {
+                "name": "phi3.5:latest",
+                "model": "phi3.5:latest"
+            },
+            {
+                "name": "embeddinggemma:300m",
+                "model": "embeddinggemma:300m"
+            },
+            {
+                "name": "embeddinggemma:latest",
+                "model": "embeddinggemma:latest"
+            },
+            {
+                "name": "unknown_embedding:latest",
+                "model": "unknown_embedding:latest"
+            }
+        ]
+    }"""
+    tags = json.loads(json_response)
+    conn = parse_ollama_tags(tags)
+
+    # Check that embedding models are properly categorized
+    assert "embeddinggemma:300m" in conn.supported_embedding_models
+    assert "embeddinggemma:latest" in conn.supported_embedding_models
+
+    # Check that regular models are still parsed correctly
+    assert "phi3.5:latest" in conn.supported_models
+
+    # Check that embedding models are NOT in the main model lists
+    assert "embeddinggemma:300m" not in conn.supported_models
+    assert "embeddinggemma:latest" not in conn.supported_models
+    assert "embeddinggemma:300m" not in conn.untested_models
+    assert "embeddinggemma:latest" not in conn.untested_models
+
+    # we assume the unknown models are normal models, not embedding models - because
+    # we don't support untested embedding models currently
+    assert "unknown_embedding:latest" not in conn.supported_embedding_models
+    assert "unknown_embedding:latest" in conn.untested_models
+
+
+@patch("kiln_ai.adapters.ollama_tools.built_in_models", MOCK_BUILT_IN_MODELS)
+@patch(
+    "kiln_ai.adapters.ollama_tools.built_in_embedding_models",
+    MOCK_BUILT_IN_EMBEDDING_MODELS,
+)
+def test_parse_ollama_tags_embedding_model_aliases():
+    """Test parsing Ollama tags response with embedding model aliases"""
+    json_response = """{
+        "models": [
+            {
+                "name": "embeddinggemma",
+                "model": "embeddinggemma"
+            }
+        ]
+    }"""
+    tags = json.loads(json_response)
+    conn = parse_ollama_tags(tags)
+
+    # Check that embedding model aliases are recognized
+    assert "embeddinggemma" in conn.supported_embedding_models
+
+    # Check that embedding model aliases are NOT in the main model lists
+    assert "embeddinggemma" not in conn.supported_models
+    assert "embeddinggemma" not in conn.untested_models
+
+    assert len(conn.supported_models) == 0
+    assert len(conn.untested_models) == 0
+    assert len(conn.supported_embedding_models) == 1
+
+
+@patch("kiln_ai.adapters.ollama_tools.built_in_models", MOCK_BUILT_IN_MODELS)
+@patch(
+    "kiln_ai.adapters.ollama_tools.built_in_embedding_models",
+    MOCK_BUILT_IN_EMBEDDING_MODELS,
+)
+def test_parse_ollama_tags_only_embedding_models():
+    """Test parsing Ollama tags response with only embedding models"""
+    json_response = """{
+        "models": [
+            {
+                "name": "embeddinggemma:300m",
+                "model": "embeddinggemma:300m"
+            }
+        ]
+    }"""
+    tags = json.loads(json_response)
+    conn = parse_ollama_tags(tags)
+
+    # Check that embedding models are found but no regular models
+    assert "embeddinggemma:300m" in conn.supported_embedding_models
+    assert conn.supported_models == []
+    assert conn.untested_models == []
+
+    # Check that embedding models are NOT in the main model lists
+    assert "embeddinggemma:300m" not in conn.supported_models
+    assert "embeddinggemma:300m" not in conn.untested_models
+
+
+def test_ollama_connection_all_embedding_models():
+    """Test OllamaConnection.all_embedding_models() method"""
+    conn = OllamaConnection(
+        message="Connected",
+        supported_models=["phi3.5:latest"],
+        untested_models=["unknown:latest"],
+        supported_embedding_models=["embeddinggemma:300m", "embeddinggemma:latest"],
+    )
+
+    embedding_models = conn.all_embedding_models()
+    assert embedding_models == ["embeddinggemma:300m", "embeddinggemma:latest"]
+
+
+def test_ollama_connection_empty_embedding_models():
+    """Test OllamaConnection.all_embedding_models() with empty list"""
+    conn = OllamaConnection(
+        message="Connected",
+        supported_models=["phi3.5:latest"],
+        untested_models=["unknown:latest"],
+        supported_embedding_models=[],
+    )
+
+    embedding_models = conn.all_embedding_models()
+    assert embedding_models == []
+
+
+@patch("kiln_ai.adapters.ollama_tools.built_in_models", MOCK_BUILT_IN_MODELS)
+@patch(
+    "kiln_ai.adapters.ollama_tools.built_in_embedding_models",
+    MOCK_BUILT_IN_EMBEDDING_MODELS,
+)
+def test_parse_ollama_tags_mixed_models_and_embeddings():
+    """Test parsing response with mix of regular models, embedding models, and unknown models"""
+    json_response = """{
+        "models": [
+            {
+                "name": "phi3.5:latest",
+                "model": "phi3.5:latest"
+            },
+            {
+                "name": "gemma2:2b",
+                "model": "gemma2:2b"
+            },
+            {
+                "name": "embeddinggemma:300m",
+                "model": "embeddinggemma:300m"
+            },
+            {
+                "name": "embeddinggemma",
+                "model": "embeddinggemma"
+            },
+            {
+                "name": "unknown_model:latest",
+                "model": "unknown_model:latest"
+            },
+            {
+                "name": "unknown_embedding:latest",
+                "model": "unknown_embedding:latest"
+            }
+        ]
+    }"""
+    tags = json.loads(json_response)
+    conn = parse_ollama_tags(tags)
+
+    # Check regular models
+    assert "phi3.5:latest" in conn.supported_models
+    assert "gemma2:2b" in conn.supported_models
+    assert "unknown_model:latest" in conn.untested_models
+
+    # Check embedding models
+    assert "embeddinggemma:300m" in conn.supported_embedding_models
+    assert "embeddinggemma" in conn.supported_embedding_models
+
+    # Check that embedding models are NOT in the main model lists
+    assert "embeddinggemma:300m" not in conn.supported_models
+    assert "embeddinggemma" not in conn.supported_models
+    assert "embeddinggemma:300m" not in conn.untested_models
+    assert "embeddinggemma" not in conn.untested_models
+
+    # Unknown embedding models should not appear in supported_embedding_models
+    assert "unknown_embedding:latest" not in conn.supported_embedding_models
+
+    # Unknown embedding models should appear in untested_models (since they're not recognized as embeddings)
+    assert "unknown_embedding:latest" not in conn.supported_models
+    assert "unknown_embedding:latest" in conn.untested_models
kiln_ai/adapters/test_prompt_builders.py

@@ -359,7 +359,7 @@ def test_prompt_builder_from_id(task_with_examples):
 
     with pytest.raises(
         ValueError,
-        match="Invalid fine-tune ID format. Expected 'project_id::task_id::fine_tune_id'",
+        match=r"Invalid fine-tune ID format. Expected 'project_id::task_id::fine_tune_id'",
     ):
         prompt_builder_from_id("fine_tune_prompt::123", task)
 
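The only change in this hunk is the r prefix: pytest.raises(match=...) treats its argument as a regular expression (applied with re.search), so a raw string makes the regex intent explicit. A small self-contained illustration of the semantics, using a hypothetical parse helper in place of prompt_builder_from_id:

import pytest


def parse(fine_tune_id: str) -> None:
    # Hypothetical stand-in for prompt_builder_from_id's validation.
    raise ValueError(
        "Invalid fine-tune ID format. Expected 'project_id::task_id::fine_tune_id'"
    )


# match= is searched as a regex against the exception text, not compared as a
# literal substring; the unescaped '.' happens to also match the literal dot.
with pytest.raises(ValueError, match=r"Invalid fine-tune ID format"):
    parse("fine_tune_prompt::123")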
kiln_ai/adapters/test_provider_tools.py

@@ -2,6 +2,7 @@ from unittest.mock import AsyncMock, Mock, patch
 
 import pytest
 
+from kiln_ai.adapters.adapter_registry import litellm_core_provider_config
 from kiln_ai.adapters.docker_model_runner_tools import DockerModelRunnerConnection
 from kiln_ai.adapters.ml_model_list import (
     KilnModel,
@@ -11,6 +12,7 @@ from kiln_ai.adapters.ml_model_list import (
 )
 from kiln_ai.adapters.ollama_tools import OllamaConnection
 from kiln_ai.adapters.provider_tools import (
+    LiteLlmCoreConfig,
     builtin_model_from,
     check_provider_warnings,
     core_provider,
@@ -19,7 +21,7 @@ from kiln_ai.adapters.provider_tools import (
     finetune_provider_model,
     get_model_and_provider,
     kiln_model_provider_from,
-    lite_llm_config_for_openai_compatible,
+    lite_llm_core_config_for_provider,
     lite_llm_provider_model,
     parse_custom_model_id,
     provider_enabled,
@@ -604,7 +606,7 @@ def test_openai_compatible_provider_config(mock_shared_config):
     """Test successful creation of an OpenAI compatible provider"""
     model_id = "test_provider::gpt-4"
 
-    config = lite_llm_config_for_openai_compatible(
+    config = litellm_core_provider_config(
         RunConfigProperties(
             model_name=model_id,
             model_provider_name=ModelProviderName.openai_compatible,
@@ -639,10 +641,10 @@ def test_lite_llm_config_no_api_key(mock_shared_config):
     """Test provider creation without API key (should work as some providers don't require it, but should pass NA to LiteLLM as it requires one)"""
     model_id = "no_key_provider::gpt-4"
 
-    config = lite_llm_config_for_openai_compatible(
+    config = litellm_core_provider_config(
         RunConfigProperties(
             model_name=model_id,
-            model_provider_name=ModelProviderName.openai,
+            model_provider_name=ModelProviderName.openai_compatible,
             prompt_id="simple_prompt_builder",
             structured_output_mode="json_schema",
         )
@@ -660,7 +662,7 @@
 def test_lite_llm_config_invalid_id():
     """Test handling of invalid model ID format"""
     with pytest.raises(ValueError) as exc_info:
-        lite_llm_config_for_openai_compatible(
+        litellm_core_provider_config(
             RunConfigProperties(
                 model_name="invalid-id-format",
                 model_provider_name=ModelProviderName.openai_compatible,
@@ -678,7 +680,7 @@ def test_lite_llm_config_no_providers(mock_shared_config):
     mock_shared_config.return_value.openai_compatible_providers = None
 
     with pytest.raises(ValueError) as exc_info:
-        lite_llm_config_for_openai_compatible(
+        litellm_core_provider_config(
             RunConfigProperties(
                 model_name="test_provider::gpt-4",
                 model_provider_name=ModelProviderName.openai_compatible,
@@ -692,7 +694,7 @@
 def test_lite_llm_config_provider_not_found(mock_shared_config):
     """Test handling of non-existent provider"""
     with pytest.raises(ValueError) as exc_info:
-        lite_llm_config_for_openai_compatible(
+        litellm_core_provider_config(
             RunConfigProperties(
                 model_name="unknown_provider::gpt-4",
                 model_provider_name=ModelProviderName.openai_compatible,
@@ -715,7 +717,7 @@ def test_lite_llm_config_no_base_url(mock_shared_config):
     ]
 
     with pytest.raises(ValueError) as exc_info:
-        lite_llm_config_for_openai_compatible(
+        litellm_core_provider_config(
             RunConfigProperties(
                 model_name="test_provider::gpt-4",
                 model_provider_name=ModelProviderName.openai_compatible,
@@ -934,6 +936,195 @@ def test_finetune_provider_model_vertex_ai(mock_project, mock_task, mock_finetun
     assert provider.structured_output_mode == StructuredOutputMode.json_mode
 
 
+@pytest.fixture
+def mock_config_for_lite_llm_core_config():
+    with patch("kiln_ai.adapters.provider_tools.Config") as mock:
+        config_instance = Mock()
+        mock.shared.return_value = config_instance
+
+        # Set up all the config values
+        config_instance.open_router_api_key = "test-openrouter-key"
+        config_instance.open_ai_api_key = "test-openai-key"
+        config_instance.groq_api_key = "test-groq-key"
+        config_instance.bedrock_access_key = "test-aws-access-key"
+        config_instance.bedrock_secret_key = "test-aws-secret-key"
+        config_instance.ollama_base_url = "http://test-ollama:11434"
+        config_instance.fireworks_api_key = "test-fireworks-key"
+        config_instance.anthropic_api_key = "test-anthropic-key"
+        config_instance.gemini_api_key = "test-gemini-key"
+        config_instance.vertex_project_id = "test-vertex-project"
+        config_instance.vertex_location = "us-central1"
+        config_instance.together_api_key = "test-together-key"
+        config_instance.azure_openai_api_key = "test-azure-key"
+        config_instance.azure_openai_endpoint = "https://test.openai.azure.com"
+        config_instance.huggingface_api_key = "test-hf-key"
+
+        yield mock
+
+
+@pytest.mark.parametrize(
+    "provider_name,expected_config",
+    [
+        (
+            ModelProviderName.openrouter,
+            LiteLlmCoreConfig(
+                base_url="https://openrouter.ai/api/v1",
+                additional_body_options={
+                    "api_key": "test-openrouter-key",
+                },
+                default_headers={
+                    "HTTP-Referer": "https://kiln.tech/openrouter",
+                    "X-Title": "KilnAI",
+                },
+            ),
+        ),
+        (
+            ModelProviderName.openai,
+            LiteLlmCoreConfig(additional_body_options={"api_key": "test-openai-key"}),
+        ),
+        (
+            ModelProviderName.groq,
+            LiteLlmCoreConfig(additional_body_options={"api_key": "test-groq-key"}),
+        ),
+        (
+            ModelProviderName.amazon_bedrock,
+            LiteLlmCoreConfig(
+                additional_body_options={
+                    "aws_access_key_id": "test-aws-access-key",
+                    "aws_secret_access_key": "test-aws-secret-key",
+                    "aws_region_name": "us-west-2",
+                },
+            ),
+        ),
+        (
+            ModelProviderName.ollama,
+            LiteLlmCoreConfig(
+                base_url="http://test-ollama:11434/v1",
+                additional_body_options={"api_key": "NA"},
+            ),
+        ),
+        (
+            ModelProviderName.fireworks_ai,
+            LiteLlmCoreConfig(
+                additional_body_options={"api_key": "test-fireworks-key"}
+            ),
+        ),
+        (
+            ModelProviderName.anthropic,
+            LiteLlmCoreConfig(
+                additional_body_options={"api_key": "test-anthropic-key"}
+            ),
+        ),
+        (
+            ModelProviderName.gemini_api,
+            LiteLlmCoreConfig(additional_body_options={"api_key": "test-gemini-key"}),
+        ),
+        (
+            ModelProviderName.vertex,
+            LiteLlmCoreConfig(
+                additional_body_options={
+                    "vertex_project": "test-vertex-project",
+                    "vertex_location": "us-central1",
+                },
+            ),
+        ),
+        (
+            ModelProviderName.together_ai,
+            LiteLlmCoreConfig(additional_body_options={"api_key": "test-together-key"}),
+        ),
+        (
+            ModelProviderName.azure_openai,
+            LiteLlmCoreConfig(
+                base_url="https://test.openai.azure.com",
+                additional_body_options={
+                    "api_key": "test-azure-key",
+                    "api_version": "2025-02-01-preview",
+                },
+            ),
+        ),
+        (
+            ModelProviderName.huggingface,
+            LiteLlmCoreConfig(additional_body_options={"api_key": "test-hf-key"}),
+        ),
+        (ModelProviderName.kiln_fine_tune, None),
+        (ModelProviderName.kiln_custom_registry, None),
+    ],
+)
+def test_lite_llm_core_config_for_provider(
+    mock_config_for_lite_llm_core_config, provider_name, expected_config
+):
+    config = lite_llm_core_config_for_provider(provider_name)
+    assert config == expected_config
+
+
+def test_lite_llm_core_config_for_provider_openai_compatible(
+    mock_shared_config,
+):
+    config = lite_llm_core_config_for_provider(
+        ModelProviderName.openai_compatible, "no_key_provider"
+    )
+    assert config is not None
+    assert config.base_url == "https://api.nokey.com"
+    assert config.additional_body_options == {"api_key": "NA"}
+
+
+def test_lite_llm_core_config_for_provider_openai_compatible_with_openai_compatible_provider_name(
+    mock_shared_config,
+):
+    with pytest.raises(
+        ValueError, match="OpenAI compatible provider requires a provider name"
+    ):
+        lite_llm_core_config_for_provider(ModelProviderName.openai_compatible)
+
+
+def test_lite_llm_core_config_incorrect_openai_compatible_provider_name(
+    mock_shared_config,
+):
+    with pytest.raises(
+        ValueError,
+        match="OpenAI compatible provider provider_that_does_not_exist_in_compatible_openai_providers not found",
+    ):
+        lite_llm_core_config_for_provider(
+            ModelProviderName.openai_compatible,
+            "provider_that_does_not_exist_in_compatible_openai_providers",
+        )
+
+
+def test_lite_llm_core_config_for_provider_with_string(
+    mock_config_for_lite_llm_core_config,
+):
+    # test with a string instead of an enum
+    config = lite_llm_core_config_for_provider("openai")
+    assert config == LiteLlmCoreConfig(
+        additional_body_options={"api_key": "test-openai-key"}
+    )
+
+
+def test_lite_llm_core_config_for_provider_unknown_provider():
+    with pytest.raises(ValueError, match="Unhandled enum value: unknown_provider"):
+        lite_llm_core_config_for_provider("unknown_provider")
+
+
+@patch.dict("os.environ", {"OPENROUTER_BASE_URL": "https://custom-openrouter.com"})
+def test_lite_llm_core_config_for_provider_openrouter_custom_url(
+    mock_config_for_lite_llm_core_config,
+):
+    config = lite_llm_core_config_for_provider(ModelProviderName.openrouter)
+    assert config is not None
+    assert config.base_url == "https://custom-openrouter.com"
+
+
+def test_lite_llm_core_config_for_provider_ollama_default_url(
+    mock_config_for_lite_llm_core_config,
+):
+    # Override the mock to return None for ollama_base_url
+    mock_config_for_lite_llm_core_config.shared.return_value.ollama_base_url = None
+
+    config = lite_llm_core_config_for_provider(ModelProviderName.ollama)
+    assert config is not None
+    assert config.base_url == "http://localhost:11434/v1"
+
+
 @pytest.mark.asyncio
 async def test_provider_enabled_docker_model_runner_success():
     """Test provider_enabled for Docker Model Runner with successful connection"""