kiln-ai: kiln_ai-0.19.0-py3-none-any.whl → kiln_ai-0.21.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of kiln-ai might be problematic. More details are linked from the registry page for this release.

Files changed (158)
  1. kiln_ai/adapters/__init__.py +8 -2
  2. kiln_ai/adapters/adapter_registry.py +43 -208
  3. kiln_ai/adapters/chat/chat_formatter.py +8 -12
  4. kiln_ai/adapters/chat/test_chat_formatter.py +6 -2
  5. kiln_ai/adapters/chunkers/__init__.py +13 -0
  6. kiln_ai/adapters/chunkers/base_chunker.py +42 -0
  7. kiln_ai/adapters/chunkers/chunker_registry.py +16 -0
  8. kiln_ai/adapters/chunkers/fixed_window_chunker.py +39 -0
  9. kiln_ai/adapters/chunkers/helpers.py +23 -0
  10. kiln_ai/adapters/chunkers/test_base_chunker.py +63 -0
  11. kiln_ai/adapters/chunkers/test_chunker_registry.py +28 -0
  12. kiln_ai/adapters/chunkers/test_fixed_window_chunker.py +346 -0
  13. kiln_ai/adapters/chunkers/test_helpers.py +75 -0
  14. kiln_ai/adapters/data_gen/test_data_gen_task.py +9 -3
  15. kiln_ai/adapters/docker_model_runner_tools.py +119 -0
  16. kiln_ai/adapters/embedding/__init__.py +0 -0
  17. kiln_ai/adapters/embedding/base_embedding_adapter.py +44 -0
  18. kiln_ai/adapters/embedding/embedding_registry.py +32 -0
  19. kiln_ai/adapters/embedding/litellm_embedding_adapter.py +199 -0
  20. kiln_ai/adapters/embedding/test_base_embedding_adapter.py +283 -0
  21. kiln_ai/adapters/embedding/test_embedding_registry.py +166 -0
  22. kiln_ai/adapters/embedding/test_litellm_embedding_adapter.py +1149 -0
  23. kiln_ai/adapters/eval/base_eval.py +2 -2
  24. kiln_ai/adapters/eval/eval_runner.py +9 -3
  25. kiln_ai/adapters/eval/g_eval.py +2 -2
  26. kiln_ai/adapters/eval/test_base_eval.py +2 -4
  27. kiln_ai/adapters/eval/test_g_eval.py +4 -5
  28. kiln_ai/adapters/extractors/__init__.py +18 -0
  29. kiln_ai/adapters/extractors/base_extractor.py +72 -0
  30. kiln_ai/adapters/extractors/encoding.py +20 -0
  31. kiln_ai/adapters/extractors/extractor_registry.py +44 -0
  32. kiln_ai/adapters/extractors/extractor_runner.py +112 -0
  33. kiln_ai/adapters/extractors/litellm_extractor.py +386 -0
  34. kiln_ai/adapters/extractors/test_base_extractor.py +244 -0
  35. kiln_ai/adapters/extractors/test_encoding.py +54 -0
  36. kiln_ai/adapters/extractors/test_extractor_registry.py +181 -0
  37. kiln_ai/adapters/extractors/test_extractor_runner.py +181 -0
  38. kiln_ai/adapters/extractors/test_litellm_extractor.py +1192 -0
  39. kiln_ai/adapters/fine_tune/__init__.py +1 -1
  40. kiln_ai/adapters/fine_tune/openai_finetune.py +14 -4
  41. kiln_ai/adapters/fine_tune/test_dataset_formatter.py +2 -2
  42. kiln_ai/adapters/fine_tune/test_fireworks_tinetune.py +2 -6
  43. kiln_ai/adapters/fine_tune/test_openai_finetune.py +108 -111
  44. kiln_ai/adapters/fine_tune/test_together_finetune.py +2 -6
  45. kiln_ai/adapters/ml_embedding_model_list.py +192 -0
  46. kiln_ai/adapters/ml_model_list.py +761 -37
  47. kiln_ai/adapters/model_adapters/base_adapter.py +51 -21
  48. kiln_ai/adapters/model_adapters/litellm_adapter.py +380 -138
  49. kiln_ai/adapters/model_adapters/test_base_adapter.py +193 -17
  50. kiln_ai/adapters/model_adapters/test_litellm_adapter.py +407 -2
  51. kiln_ai/adapters/model_adapters/test_litellm_adapter_tools.py +1103 -0
  52. kiln_ai/adapters/model_adapters/test_saving_adapter_results.py +5 -5
  53. kiln_ai/adapters/model_adapters/test_structured_output.py +113 -5
  54. kiln_ai/adapters/ollama_tools.py +69 -12
  55. kiln_ai/adapters/parsers/__init__.py +1 -1
  56. kiln_ai/adapters/provider_tools.py +205 -47
  57. kiln_ai/adapters/rag/deduplication.py +49 -0
  58. kiln_ai/adapters/rag/progress.py +252 -0
  59. kiln_ai/adapters/rag/rag_runners.py +844 -0
  60. kiln_ai/adapters/rag/test_deduplication.py +195 -0
  61. kiln_ai/adapters/rag/test_progress.py +785 -0
  62. kiln_ai/adapters/rag/test_rag_runners.py +2376 -0
  63. kiln_ai/adapters/remote_config.py +80 -8
  64. kiln_ai/adapters/repair/test_repair_task.py +12 -9
  65. kiln_ai/adapters/run_output.py +3 -0
  66. kiln_ai/adapters/test_adapter_registry.py +657 -85
  67. kiln_ai/adapters/test_docker_model_runner_tools.py +305 -0
  68. kiln_ai/adapters/test_ml_embedding_model_list.py +429 -0
  69. kiln_ai/adapters/test_ml_model_list.py +251 -1
  70. kiln_ai/adapters/test_ollama_tools.py +340 -1
  71. kiln_ai/adapters/test_prompt_adaptors.py +13 -6
  72. kiln_ai/adapters/test_prompt_builders.py +1 -1
  73. kiln_ai/adapters/test_provider_tools.py +254 -8
  74. kiln_ai/adapters/test_remote_config.py +651 -58
  75. kiln_ai/adapters/vector_store/__init__.py +1 -0
  76. kiln_ai/adapters/vector_store/base_vector_store_adapter.py +83 -0
  77. kiln_ai/adapters/vector_store/lancedb_adapter.py +389 -0
  78. kiln_ai/adapters/vector_store/test_base_vector_store.py +160 -0
  79. kiln_ai/adapters/vector_store/test_lancedb_adapter.py +1841 -0
  80. kiln_ai/adapters/vector_store/test_vector_store_registry.py +199 -0
  81. kiln_ai/adapters/vector_store/vector_store_registry.py +33 -0
  82. kiln_ai/datamodel/__init__.py +39 -34
  83. kiln_ai/datamodel/basemodel.py +170 -1
  84. kiln_ai/datamodel/chunk.py +158 -0
  85. kiln_ai/datamodel/datamodel_enums.py +28 -0
  86. kiln_ai/datamodel/embedding.py +64 -0
  87. kiln_ai/datamodel/eval.py +1 -1
  88. kiln_ai/datamodel/external_tool_server.py +298 -0
  89. kiln_ai/datamodel/extraction.py +303 -0
  90. kiln_ai/datamodel/json_schema.py +25 -10
  91. kiln_ai/datamodel/project.py +40 -1
  92. kiln_ai/datamodel/rag.py +79 -0
  93. kiln_ai/datamodel/registry.py +0 -15
  94. kiln_ai/datamodel/run_config.py +62 -0
  95. kiln_ai/datamodel/task.py +2 -77
  96. kiln_ai/datamodel/task_output.py +6 -1
  97. kiln_ai/datamodel/task_run.py +41 -0
  98. kiln_ai/datamodel/test_attachment.py +649 -0
  99. kiln_ai/datamodel/test_basemodel.py +4 -4
  100. kiln_ai/datamodel/test_chunk_models.py +317 -0
  101. kiln_ai/datamodel/test_dataset_split.py +1 -1
  102. kiln_ai/datamodel/test_embedding_models.py +448 -0
  103. kiln_ai/datamodel/test_eval_model.py +6 -6
  104. kiln_ai/datamodel/test_example_models.py +175 -0
  105. kiln_ai/datamodel/test_external_tool_server.py +691 -0
  106. kiln_ai/datamodel/test_extraction_chunk.py +206 -0
  107. kiln_ai/datamodel/test_extraction_model.py +470 -0
  108. kiln_ai/datamodel/test_rag.py +641 -0
  109. kiln_ai/datamodel/test_registry.py +8 -3
  110. kiln_ai/datamodel/test_task.py +15 -47
  111. kiln_ai/datamodel/test_tool_id.py +320 -0
  112. kiln_ai/datamodel/test_vector_store.py +320 -0
  113. kiln_ai/datamodel/tool_id.py +105 -0
  114. kiln_ai/datamodel/vector_store.py +141 -0
  115. kiln_ai/tools/__init__.py +8 -0
  116. kiln_ai/tools/base_tool.py +82 -0
  117. kiln_ai/tools/built_in_tools/__init__.py +13 -0
  118. kiln_ai/tools/built_in_tools/math_tools.py +124 -0
  119. kiln_ai/tools/built_in_tools/test_math_tools.py +204 -0
  120. kiln_ai/tools/mcp_server_tool.py +95 -0
  121. kiln_ai/tools/mcp_session_manager.py +246 -0
  122. kiln_ai/tools/rag_tools.py +157 -0
  123. kiln_ai/tools/test_base_tools.py +199 -0
  124. kiln_ai/tools/test_mcp_server_tool.py +457 -0
  125. kiln_ai/tools/test_mcp_session_manager.py +1585 -0
  126. kiln_ai/tools/test_rag_tools.py +848 -0
  127. kiln_ai/tools/test_tool_registry.py +562 -0
  128. kiln_ai/tools/tool_registry.py +85 -0
  129. kiln_ai/utils/__init__.py +3 -0
  130. kiln_ai/utils/async_job_runner.py +62 -17
  131. kiln_ai/utils/config.py +24 -2
  132. kiln_ai/utils/env.py +15 -0
  133. kiln_ai/utils/filesystem.py +14 -0
  134. kiln_ai/utils/filesystem_cache.py +60 -0
  135. kiln_ai/utils/litellm.py +94 -0
  136. kiln_ai/utils/lock.py +100 -0
  137. kiln_ai/utils/mime_type.py +38 -0
  138. kiln_ai/utils/open_ai_types.py +94 -0
  139. kiln_ai/utils/pdf_utils.py +38 -0
  140. kiln_ai/utils/project_utils.py +17 -0
  141. kiln_ai/utils/test_async_job_runner.py +151 -35
  142. kiln_ai/utils/test_config.py +138 -1
  143. kiln_ai/utils/test_env.py +142 -0
  144. kiln_ai/utils/test_filesystem_cache.py +316 -0
  145. kiln_ai/utils/test_litellm.py +206 -0
  146. kiln_ai/utils/test_lock.py +185 -0
  147. kiln_ai/utils/test_mime_type.py +66 -0
  148. kiln_ai/utils/test_open_ai_types.py +131 -0
  149. kiln_ai/utils/test_pdf_utils.py +73 -0
  150. kiln_ai/utils/test_uuid.py +111 -0
  151. kiln_ai/utils/test_validation.py +524 -0
  152. kiln_ai/utils/uuid.py +9 -0
  153. kiln_ai/utils/validation.py +90 -0
  154. {kiln_ai-0.19.0.dist-info → kiln_ai-0.21.0.dist-info}/METADATA +12 -5
  155. kiln_ai-0.21.0.dist-info/RECORD +211 -0
  156. kiln_ai-0.19.0.dist-info/RECORD +0 -115
  157. {kiln_ai-0.19.0.dist-info → kiln_ai-0.21.0.dist-info}/WHEEL +0 -0
  158. {kiln_ai-0.19.0.dist-info → kiln_ai-0.21.0.dist-info}/licenses/LICENSE.txt +0 -0
@@ -1,8 +1,11 @@
1
+ from collections import Counter
2
+
1
3
  import pytest
2
4
 
3
5
  from kiln_ai.adapters.ml_model_list import (
4
6
  ModelName,
5
7
  built_in_models,
8
+ built_in_models_from_provider,
6
9
  default_structured_output_mode_for_model_provider,
7
10
  get_model_by_name,
8
11
  )
@@ -132,7 +135,7 @@ class TestDefaultStructuredOutputModeForModelProvider:
132
135
  ("llama_3_1_8b", ModelProviderName.groq, StructuredOutputMode.default),
133
136
  (
134
137
  "qwq_32b",
135
- ModelProviderName.fireworks_ai,
138
+ ModelProviderName.together_ai,
136
139
  StructuredOutputMode.json_instructions,
137
140
  ),
138
141
  ],
@@ -159,6 +162,186 @@ class TestDefaultStructuredOutputModeForModelProvider:
159
162
  assert result == first_provider.structured_output_mode
160
163
 
161
164
 
165
+ class TestBuiltInModelsFromProvider:
166
+ """Test cases for built_in_models_from_provider function"""
167
+
168
+ def test_valid_model_and_provider_returns_provider(self):
169
+ """Test that valid model and provider returns the correct provider configuration"""
170
+ # GPT 4.1 has OpenAI provider
171
+ result = built_in_models_from_provider(
172
+ provider_name=ModelProviderName.openai,
173
+ model_name="gpt_4_1",
174
+ )
175
+ assert result is not None
176
+ assert result.name == ModelProviderName.openai
177
+ assert result.model_id == "gpt-4.1"
178
+ assert result.supports_logprobs is True
179
+ assert result.suggested_for_data_gen is True
180
+
181
+ def test_valid_model_different_provider_returns_correct_provider(self):
182
+ """Test that different providers for the same model return different configurations"""
183
+ # GPT 4.1 has multiple providers with different configurations
184
+ openai_provider = built_in_models_from_provider(
185
+ provider_name=ModelProviderName.openai,
186
+ model_name="gpt_4_1",
187
+ )
188
+ openrouter_provider = built_in_models_from_provider(
189
+ provider_name=ModelProviderName.openrouter,
190
+ model_name="gpt_4_1",
191
+ )
192
+
193
+ assert openai_provider is not None
194
+ assert openrouter_provider is not None
195
+ assert openai_provider.name == ModelProviderName.openai
196
+ assert openrouter_provider.name == ModelProviderName.openrouter
197
+ assert openai_provider.model_id == "gpt-4.1"
198
+ assert openrouter_provider.model_id == "openai/gpt-4.1"
199
+
200
+ def test_invalid_model_name_returns_none(self):
201
+ """Test that invalid model name returns None"""
202
+ result = built_in_models_from_provider(
203
+ provider_name=ModelProviderName.openai,
204
+ model_name="invalid_model_name",
205
+ )
206
+ assert result is None
207
+
208
+ def test_valid_model_invalid_provider_returns_none(self):
209
+ """Test that valid model but invalid provider returns None"""
210
+ result = built_in_models_from_provider(
211
+ provider_name=ModelProviderName.gemini_api, # GPT 4.1 doesn't have gemini_api provider
212
+ model_name="gpt_4_1",
213
+ )
214
+ assert result is None
215
+
216
+ def test_model_with_single_provider(self):
217
+ """Test model that only has one provider"""
218
+ # Find a model with only one provider for this test
219
+ model = get_model_by_name(ModelName.gpt_4_1_nano)
220
+ assert len(model.providers) >= 1 # Verify it has providers
221
+
222
+ first_provider = model.providers[0]
223
+ result = built_in_models_from_provider(
224
+ provider_name=first_provider.name,
225
+ model_name="gpt_4_1_nano",
226
+ )
227
+ assert result is not None
228
+ assert result.name == first_provider.name
229
+ assert result.model_id == first_provider.model_id
230
+
231
+ def test_model_with_multiple_providers(self):
232
+ """Test model that has multiple providers"""
233
+ # GPT 4.1 has multiple providers
234
+ openai_result = built_in_models_from_provider(
235
+ provider_name=ModelProviderName.openai,
236
+ model_name="gpt_4_1",
237
+ )
238
+ openrouter_result = built_in_models_from_provider(
239
+ provider_name=ModelProviderName.openrouter,
240
+ model_name="gpt_4_1",
241
+ )
242
+ azure_result = built_in_models_from_provider(
243
+ provider_name=ModelProviderName.azure_openai,
244
+ model_name="gpt_4_1",
245
+ )
246
+
247
+ assert openai_result is not None
248
+ assert openrouter_result is not None
249
+ assert azure_result is not None
250
+ assert openai_result.name == ModelProviderName.openai
251
+ assert openrouter_result.name == ModelProviderName.openrouter
252
+ assert azure_result.name == ModelProviderName.azure_openai
253
+
254
+ def test_case_sensitive_model_name(self):
255
+ """Test that model name matching is case sensitive"""
256
+ # Should return None for case-mismatched model name
257
+ result = built_in_models_from_provider(
258
+ provider_name=ModelProviderName.openai,
259
+ model_name="GPT_4_1", # Wrong case
260
+ )
261
+ assert result is None
262
+
263
+ def test_provider_specific_attributes(self):
264
+ """Test that provider-specific attributes are correctly returned"""
265
+ # Test a provider with specific attributes like multimodal_capable
266
+ result = built_in_models_from_provider(
267
+ provider_name=ModelProviderName.openai,
268
+ model_name="gpt_4_1",
269
+ )
270
+ assert result is not None
271
+ assert result.multimodal_capable is True
272
+ assert result.supports_doc_extraction is True
273
+ assert result.multimodal_mime_types is not None
274
+ assert len(result.multimodal_mime_types) > 0
275
+
276
+ def test_provider_without_special_attributes(self):
277
+ """Test provider that doesn't have special attributes"""
278
+ # Test a simpler provider configuration
279
+ result = built_in_models_from_provider(
280
+ provider_name=ModelProviderName.openai,
281
+ model_name="gpt_4_1_nano",
282
+ )
283
+ assert result is not None
284
+ assert result.multimodal_capable is False # Should be False for nano
285
+ assert result.supports_doc_extraction is False
286
+
287
+ @pytest.mark.parametrize(
288
+ "model_name,provider,expected_model_id",
289
+ [
290
+ ("gpt_4o", ModelProviderName.openai, "gpt-4o"),
291
+ (
292
+ "claude_3_5_haiku",
293
+ ModelProviderName.anthropic,
294
+ "claude-3-5-haiku-20241022",
295
+ ),
296
+ ("gemini_2_5_pro", ModelProviderName.gemini_api, "gemini-2.5-pro"),
297
+ ("llama_3_1_8b", ModelProviderName.groq, "llama-3.1-8b-instant"),
298
+ ],
299
+ )
300
+ def test_parametrized_valid_combinations(
301
+ self, model_name, provider, expected_model_id
302
+ ):
303
+ """Test multiple valid model/provider combinations"""
304
+ result = built_in_models_from_provider(
305
+ provider_name=provider,
306
+ model_name=model_name,
307
+ )
308
+ assert result is not None
309
+ assert result.name == provider
310
+ assert result.model_id == expected_model_id
311
+
312
+ def test_empty_string_model_name(self):
313
+ """Test that empty string model name returns None"""
314
+ result = built_in_models_from_provider(
315
+ provider_name=ModelProviderName.openai,
316
+ model_name="",
317
+ )
318
+ assert result is None
319
+
320
+ def test_none_model_name(self):
321
+ """Test that None model name returns None"""
322
+ result = built_in_models_from_provider(
323
+ provider_name=ModelProviderName.openai,
324
+ model_name=None, # type: ignore
325
+ )
326
+ assert result is None
327
+
328
+ def test_all_built_in_models_have_valid_providers(self):
329
+ """Test that all built-in models have at least one valid provider"""
330
+ for model in built_in_models:
331
+ assert len(model.providers) > 0, f"Model {model.name} has no providers"
332
+ for provider in model.providers:
333
+ # Test that we can retrieve each provider
334
+ result = built_in_models_from_provider(
335
+ provider_name=provider.name,
336
+ model_name=model.name,
337
+ )
338
+ assert result is not None, (
339
+ f"Could not retrieve provider {provider.name} for model {model.name}"
340
+ )
341
+ assert result.name == provider.name
342
+ assert result.model_id == provider.model_id
343
+
344
+
162
345
  def test_uncensored():
163
346
  """Test that uncensored is set correctly"""
164
347
  model = get_model_by_name(ModelName.grok_3_mini)
@@ -177,6 +360,37 @@ def test_uncensored():
177
360
  assert provider.suggested_for_uncensored_data_gen
178
361
 
179
362
 
363
+ def test_multimodal_capable():
364
+ """Test that multimodal_capable is set correctly"""
365
+ model = get_model_by_name(ModelName.gpt_4_1)
366
+ for provider in model.providers:
367
+ assert provider.multimodal_capable
368
+ assert provider.supports_doc_extraction
369
+ assert provider.multimodal_mime_types is not None
370
+ assert len(provider.multimodal_mime_types) > 0
371
+
372
+
373
+ def test_no_empty_multimodal_mime_types():
374
+ """Ensure that multimodal fields are self-consistent as they are interdependent"""
375
+ for model in built_in_models:
376
+ for provider in model.providers:
377
+ # a multimodal model should always have mime types it supports
378
+ if provider.multimodal_capable:
379
+ assert provider.multimodal_mime_types is not None
380
+ assert len(provider.multimodal_mime_types) > 0
381
+
382
+ # a model that specifies mime types is necessarily multimodal
383
+ if (
384
+ provider.multimodal_mime_types is not None
385
+ and len(provider.multimodal_mime_types) > 0
386
+ ):
387
+ assert provider.multimodal_capable
388
+
389
+ # a model that supports doc extraction is necessarily multimodal
390
+ if provider.supports_doc_extraction:
391
+ assert provider.multimodal_capable
392
+
393
+
180
394
  def test_no_reasoning_for_structured_output():
181
395
  """Test that no reasoning is returned for structured output"""
182
396
  # get all models
@@ -186,3 +400,39 @@ def test_no_reasoning_for_structured_output():
186
400
  assert provider.reasoning_capable, (
187
401
  f"{model.name} {provider.name} has reasoning_optional_for_structured_output but is not reasoning capable. This field should only be defined for models that are reasoning capable."
188
402
  )
403
+
404
+
405
+ def test_unique_providers_per_model():
406
+ """Test that each model can only have one entry per provider"""
407
+ for model in built_in_models:
408
+ provider_names = [provider.name for provider in model.providers]
409
+ unique_provider_names = set(provider_names)
410
+
411
+ if len(provider_names) != len(unique_provider_names):
412
+ # Find which providers have duplicates
413
+ provider_counts = Counter(provider_names)
414
+ duplicates = {
415
+ name: count for name, count in provider_counts.items() if count > 1
416
+ }
417
+
418
+ # Show details about duplicates
419
+ duplicate_details = []
420
+ for provider_name, count in duplicates.items():
421
+ duplicate_providers = [
422
+ p for p in model.providers if p.name == provider_name
423
+ ]
424
+ model_ids = [p.model_id for p in duplicate_providers]
425
+ duplicate_details.append(
426
+ f"{provider_name} (appears {count} times with model_ids: {model_ids})"
427
+ )
428
+
429
+ assert False, (
430
+ f"Model {model.name} has duplicate providers:\n"
431
+ f"Expected: 1 entry per provider\n"
432
+ f"Found: {len(provider_names)} total entries, {len(unique_provider_names)} unique providers\n"
433
+ f"Duplicates: {', '.join(duplicate_details)}\n"
434
+ f"This suggests either:\n"
435
+ f"1. A bug where the same provider is accidentally duplicated, or\n"
436
+ f"2. Intentional design where the same provider offers different model variants\n"
437
+ f"If this is intentional, the test should be updated to allow multiple entries per provider."
438
+ )
@@ -1,13 +1,85 @@
1
1
  import json
2
+ from unittest.mock import patch
2
3
 
4
+ import pytest
5
+
6
+ from kiln_ai.adapters.ml_embedding_model_list import (
7
+ KilnEmbeddingModel,
8
+ KilnEmbeddingModelProvider,
9
+ )
10
+
11
+ # Mock data for testing - using proper Pydantic model instances
12
+ from kiln_ai.adapters.ml_model_list import KilnModel, KilnModelProvider
3
13
  from kiln_ai.adapters.ollama_tools import (
4
14
  OllamaConnection,
15
+ ollama_embedding_model_installed,
5
16
  ollama_model_installed,
6
17
  parse_ollama_tags,
7
18
  )
19
+ from kiln_ai.datamodel.datamodel_enums import ModelProviderName
8
20
 
21
+ MOCK_BUILT_IN_MODELS = [
22
+ KilnModel(
23
+ family="phi",
24
+ name="phi3.5",
25
+ friendly_name="phi3.5",
26
+ providers=[
27
+ KilnModelProvider(
28
+ name=ModelProviderName.ollama,
29
+ model_id="phi3.5",
30
+ ollama_model_aliases=None,
31
+ )
32
+ ],
33
+ ),
34
+ KilnModel(
35
+ family="gemma",
36
+ name="gemma2",
37
+ friendly_name="gemma2",
38
+ providers=[
39
+ KilnModelProvider(
40
+ name=ModelProviderName.ollama,
41
+ model_id="gemma2:2b",
42
+ ollama_model_aliases=None,
43
+ )
44
+ ],
45
+ ),
46
+ KilnModel(
47
+ family="llama",
48
+ name="llama3.1",
49
+ friendly_name="llama3.1",
50
+ providers=[
51
+ KilnModelProvider(
52
+ name=ModelProviderName.ollama,
53
+ model_id="llama3.1",
54
+ ollama_model_aliases=None,
55
+ )
56
+ ],
57
+ ),
58
+ ]
9
59
 
10
- def test_parse_ollama_tags_no_models():
60
+ MOCK_BUILT_IN_EMBEDDING_MODELS = [
61
+ KilnEmbeddingModel(
62
+ family="gemma",
63
+ name="embeddinggemma",
64
+ friendly_name="embeddinggemma",
65
+ providers=[
66
+ KilnEmbeddingModelProvider(
67
+ name=ModelProviderName.ollama,
68
+ model_id="embeddinggemma:300m",
69
+ n_dimensions=768,
70
+ ollama_model_aliases=["embeddinggemma"],
71
+ )
72
+ ],
73
+ ),
74
+ ]
75
+
76
+
77
+ @patch("kiln_ai.adapters.ollama_tools.built_in_models", MOCK_BUILT_IN_MODELS)
78
+ @patch(
79
+ "kiln_ai.adapters.ollama_tools.built_in_embedding_models",
80
+ MOCK_BUILT_IN_EMBEDDING_MODELS,
81
+ )
82
+ def test_parse_ollama_tags_models():
11
83
  json_response = '{"models":[{"name":"scosman_net","model":"scosman_net:latest"},{"name":"phi3.5:latest","model":"phi3.5:latest","modified_at":"2024-10-02T12:04:35.191519822-04:00","size":2176178843,"digest":"61819fb370a3c1a9be6694869331e5f85f867a079e9271d66cb223acb81d04ba","details":{"parent_model":"","format":"gguf","family":"phi3","families":["phi3"],"parameter_size":"3.8B","quantization_level":"Q4_0"}},{"name":"gemma2:2b","model":"gemma2:2b","modified_at":"2024-09-09T16:46:38.64348929-04:00","size":1629518495,"digest":"8ccf136fdd5298f3ffe2d69862750ea7fb56555fa4d5b18c04e3fa4d82ee09d7","details":{"parent_model":"","format":"gguf","family":"gemma2","families":["gemma2"],"parameter_size":"2.6B","quantization_level":"Q4_0"}},{"name":"llama3.1:latest","model":"llama3.1:latest","modified_at":"2024-09-01T17:19:43.481523695-04:00","size":4661230720,"digest":"f66fc8dc39ea206e03ff6764fcc696b1b4dfb693f0b6ef751731dd4e6269046e","details":{"parent_model":"","format":"gguf","family":"llama","families":["llama"],"parameter_size":"8.0B","quantization_level":"Q4_0"}}]}'
12
84
  tags = json.loads(json_response)
13
85
  conn = parse_ollama_tags(tags)
@@ -16,7 +88,52 @@ def test_parse_ollama_tags_no_models():
16
88
  assert "llama3.1:latest" in conn.supported_models
17
89
  assert "scosman_net:latest" in conn.untested_models
18
90
 
91
+ # there should be no embedding models because the tags response does not include any embedding models
92
+ # that are in the built-in embedding models list
93
+ assert len(conn.supported_embedding_models) == 0
19
94
 
95
+
96
+ @patch("kiln_ai.adapters.ollama_tools.built_in_models", MOCK_BUILT_IN_MODELS)
97
+ @patch(
98
+ "kiln_ai.adapters.ollama_tools.built_in_embedding_models",
99
+ MOCK_BUILT_IN_EMBEDDING_MODELS,
100
+ )
101
+ @pytest.mark.parametrize("json_response", ["{}", '{"models": []}'])
102
+ def test_parse_ollama_tags_no_models(json_response):
103
+ tags = json.loads(json_response)
104
+ conn = parse_ollama_tags(tags)
105
+ assert (
106
+ conn.message
107
+ == "Ollama is running, but no supported models are installed. Install one or more supported model, like 'ollama pull phi3.5'."
108
+ )
109
+ assert len(conn.supported_models) == 0
110
+ assert len(conn.untested_models) == 0
111
+ assert len(conn.supported_embedding_models) == 0
112
+
113
+
114
+ @patch("kiln_ai.adapters.ollama_tools.built_in_models", MOCK_BUILT_IN_MODELS)
115
+ @patch(
116
+ "kiln_ai.adapters.ollama_tools.built_in_embedding_models",
117
+ MOCK_BUILT_IN_EMBEDDING_MODELS,
118
+ )
119
+ def test_parse_ollama_tags_empty_models():
120
+ """Test parsing Ollama tags response with empty models list"""
121
+ json_response = '{"models": []}'
122
+ tags = json.loads(json_response)
123
+ conn = parse_ollama_tags(tags)
124
+
125
+ # Check that connection indicates no supported models
126
+ assert conn.supported_models == []
127
+ assert conn.untested_models == []
128
+ assert conn.supported_embedding_models == []
129
+ assert "no supported models are installed" in conn.message
130
+
131
+
132
+ @patch("kiln_ai.adapters.ollama_tools.built_in_models", MOCK_BUILT_IN_MODELS)
133
+ @patch(
134
+ "kiln_ai.adapters.ollama_tools.built_in_embedding_models",
135
+ MOCK_BUILT_IN_EMBEDDING_MODELS,
136
+ )
20
137
  def test_parse_ollama_tags_only_untested_models():
21
138
  json_response = '{"models":[{"name":"scosman_net","model":"scosman_net:latest"}]}'
22
139
  tags = json.loads(json_response)
@@ -24,12 +141,17 @@ def test_parse_ollama_tags_only_untested_models():
24
141
  assert conn.supported_models == []
25
142
  assert conn.untested_models == ["scosman_net:latest"]
26
143
 
144
+ # there should be no embedding models because the tags response does not include any embedding models
145
+ # that are in the built-in embedding models list
146
+ assert len(conn.supported_embedding_models) == 0
147
+
27
148
 
28
149
  def test_ollama_model_installed():
29
150
  conn = OllamaConnection(
30
151
  supported_models=["phi3.5:latest", "gemma2:2b", "llama3.1:latest"],
31
152
  message="Connected",
32
153
  untested_models=["scosman_net:latest"],
154
+ supported_embedding_models=["embeddinggemma:300m"],
33
155
  )
34
156
  assert ollama_model_installed(conn, "phi3.5:latest")
35
157
  assert ollama_model_installed(conn, "phi3.5")
@@ -39,3 +161,220 @@ def test_ollama_model_installed():
39
161
  assert ollama_model_installed(conn, "scosman_net:latest")
40
162
  assert ollama_model_installed(conn, "scosman_net")
41
163
  assert not ollama_model_installed(conn, "unknown_model")
164
+
165
+ # use the ollama_embedding_model_installed for testing embedding models installed, not ollama_model_installed
166
+ assert not ollama_model_installed(conn, "embeddinggemma:300m")
167
+ assert not ollama_model_installed(conn, "embeddinggemma")
168
+
169
+
170
+ def test_ollama_model_installed_embedding_models():
171
+ conn = OllamaConnection(
172
+ message="Connected",
173
+ supported_models=["phi3.5:latest", "gemma2:2b", "llama3.1:latest"],
174
+ untested_models=["scosman_net:latest"],
175
+ supported_embedding_models=["embeddinggemma:300m", "embeddinggemma:latest"],
176
+ )
177
+
178
+ assert ollama_embedding_model_installed(conn, "embeddinggemma:300m")
179
+ assert ollama_embedding_model_installed(conn, "embeddinggemma:latest")
180
+ assert not ollama_embedding_model_installed(conn, "unknown_embedding")
181
+
182
+ # use the ollama_model_installed for testing regular models installed, not ollama_embedding_model_installed
183
+ assert not ollama_embedding_model_installed(conn, "phi3.5:latest")
184
+ assert not ollama_embedding_model_installed(conn, "gemma2:2b")
185
+ assert not ollama_embedding_model_installed(conn, "llama3.1:latest")
186
+ assert not ollama_embedding_model_installed(conn, "scosman_net:latest")
187
+
188
+
189
+ @patch("kiln_ai.adapters.ollama_tools.built_in_models", MOCK_BUILT_IN_MODELS)
190
+ @patch(
191
+ "kiln_ai.adapters.ollama_tools.built_in_embedding_models",
192
+ MOCK_BUILT_IN_EMBEDDING_MODELS,
193
+ )
194
+ def test_parse_ollama_tags_with_embedding_models():
195
+ """Test parsing Ollama tags response that includes embedding models"""
196
+ json_response = """{
197
+ "models": [
198
+ {
199
+ "name": "phi3.5:latest",
200
+ "model": "phi3.5:latest"
201
+ },
202
+ {
203
+ "name": "embeddinggemma:300m",
204
+ "model": "embeddinggemma:300m"
205
+ },
206
+ {
207
+ "name": "embeddinggemma:latest",
208
+ "model": "embeddinggemma:latest"
209
+ },
210
+ {
211
+ "name": "unknown_embedding:latest",
212
+ "model": "unknown_embedding:latest"
213
+ }
214
+ ]
215
+ }"""
216
+ tags = json.loads(json_response)
217
+ conn = parse_ollama_tags(tags)
218
+
219
+ # Check that embedding models are properly categorized
220
+ assert "embeddinggemma:300m" in conn.supported_embedding_models
221
+ assert "embeddinggemma:latest" in conn.supported_embedding_models
222
+
223
+ # Check that regular models are still parsed correctly
224
+ assert "phi3.5:latest" in conn.supported_models
225
+
226
+ # Check that embedding models are NOT in the main model lists
227
+ assert "embeddinggemma:300m" not in conn.supported_models
228
+ assert "embeddinggemma:latest" not in conn.supported_models
229
+ assert "embeddinggemma:300m" not in conn.untested_models
230
+ assert "embeddinggemma:latest" not in conn.untested_models
231
+
232
+ # we assume the unknown models are normal models, not embedding models - because
233
+ # we don't support untested embedding models currently
234
+ assert "unknown_embedding:latest" not in conn.supported_embedding_models
235
+ assert "unknown_embedding:latest" in conn.untested_models
236
+
237
+
238
+ @patch("kiln_ai.adapters.ollama_tools.built_in_models", MOCK_BUILT_IN_MODELS)
239
+ @patch(
240
+ "kiln_ai.adapters.ollama_tools.built_in_embedding_models",
241
+ MOCK_BUILT_IN_EMBEDDING_MODELS,
242
+ )
243
+ def test_parse_ollama_tags_embedding_model_aliases():
244
+ """Test parsing Ollama tags response with embedding model aliases"""
245
+ json_response = """{
246
+ "models": [
247
+ {
248
+ "name": "embeddinggemma",
249
+ "model": "embeddinggemma"
250
+ }
251
+ ]
252
+ }"""
253
+ tags = json.loads(json_response)
254
+ conn = parse_ollama_tags(tags)
255
+
256
+ # Check that embedding model aliases are recognized
257
+ assert "embeddinggemma" in conn.supported_embedding_models
258
+
259
+ # Check that embedding model aliases are NOT in the main model lists
260
+ assert "embeddinggemma" not in conn.supported_models
261
+ assert "embeddinggemma" not in conn.untested_models
262
+
263
+ assert len(conn.supported_models) == 0
264
+ assert len(conn.untested_models) == 0
265
+ assert len(conn.supported_embedding_models) == 1
266
+
267
+
268
+ @patch("kiln_ai.adapters.ollama_tools.built_in_models", MOCK_BUILT_IN_MODELS)
269
+ @patch(
270
+ "kiln_ai.adapters.ollama_tools.built_in_embedding_models",
271
+ MOCK_BUILT_IN_EMBEDDING_MODELS,
272
+ )
273
+ def test_parse_ollama_tags_only_embedding_models():
274
+ """Test parsing Ollama tags response with only embedding models"""
275
+ json_response = """{
276
+ "models": [
277
+ {
278
+ "name": "embeddinggemma:300m",
279
+ "model": "embeddinggemma:300m"
280
+ }
281
+ ]
282
+ }"""
283
+ tags = json.loads(json_response)
284
+ conn = parse_ollama_tags(tags)
285
+
286
+ # Check that embedding models are found but no regular models
287
+ assert "embeddinggemma:300m" in conn.supported_embedding_models
288
+ assert conn.supported_models == []
289
+ assert conn.untested_models == []
290
+
291
+ # Check that embedding models are NOT in the main model lists
292
+ assert "embeddinggemma:300m" not in conn.supported_models
293
+ assert "embeddinggemma:300m" not in conn.untested_models
294
+
295
+
296
+ def test_ollama_connection_all_embedding_models():
297
+ """Test OllamaConnection.all_embedding_models() method"""
298
+ conn = OllamaConnection(
299
+ message="Connected",
300
+ supported_models=["phi3.5:latest"],
301
+ untested_models=["unknown:latest"],
302
+ supported_embedding_models=["embeddinggemma:300m", "embeddinggemma:latest"],
303
+ )
304
+
305
+ embedding_models = conn.all_embedding_models()
306
+ assert embedding_models == ["embeddinggemma:300m", "embeddinggemma:latest"]
307
+
308
+
309
+ def test_ollama_connection_empty_embedding_models():
310
+ """Test OllamaConnection.all_embedding_models() with empty list"""
311
+ conn = OllamaConnection(
312
+ message="Connected",
313
+ supported_models=["phi3.5:latest"],
314
+ untested_models=["unknown:latest"],
315
+ supported_embedding_models=[],
316
+ )
317
+
318
+ embedding_models = conn.all_embedding_models()
319
+ assert embedding_models == []
320
+
321
+
322
+ @patch("kiln_ai.adapters.ollama_tools.built_in_models", MOCK_BUILT_IN_MODELS)
323
+ @patch(
324
+ "kiln_ai.adapters.ollama_tools.built_in_embedding_models",
325
+ MOCK_BUILT_IN_EMBEDDING_MODELS,
326
+ )
327
+ def test_parse_ollama_tags_mixed_models_and_embeddings():
328
+ """Test parsing response with mix of regular models, embedding models, and unknown models"""
329
+ json_response = """{
330
+ "models": [
331
+ {
332
+ "name": "phi3.5:latest",
333
+ "model": "phi3.5:latest"
334
+ },
335
+ {
336
+ "name": "gemma2:2b",
337
+ "model": "gemma2:2b"
338
+ },
339
+ {
340
+ "name": "embeddinggemma:300m",
341
+ "model": "embeddinggemma:300m"
342
+ },
343
+ {
344
+ "name": "embeddinggemma",
345
+ "model": "embeddinggemma"
346
+ },
347
+ {
348
+ "name": "unknown_model:latest",
349
+ "model": "unknown_model:latest"
350
+ },
351
+ {
352
+ "name": "unknown_embedding:latest",
353
+ "model": "unknown_embedding:latest"
354
+ }
355
+ ]
356
+ }"""
357
+ tags = json.loads(json_response)
358
+ conn = parse_ollama_tags(tags)
359
+
360
+ # Check regular models
361
+ assert "phi3.5:latest" in conn.supported_models
362
+ assert "gemma2:2b" in conn.supported_models
363
+ assert "unknown_model:latest" in conn.untested_models
364
+
365
+ # Check embedding models
366
+ assert "embeddinggemma:300m" in conn.supported_embedding_models
367
+ assert "embeddinggemma" in conn.supported_embedding_models
368
+
369
+ # Check that embedding models are NOT in the main model lists
370
+ assert "embeddinggemma:300m" not in conn.supported_models
371
+ assert "embeddinggemma" not in conn.supported_models
372
+ assert "embeddinggemma:300m" not in conn.untested_models
373
+ assert "embeddinggemma" not in conn.untested_models
374
+
375
+ # Unknown embedding models should not appear in supported_embedding_models
376
+ assert "unknown_embedding:latest" not in conn.supported_embedding_models
377
+
378
+ # Unknown embedding models should appear in untested_models (since they're not recognized as embeddings)
379
+ assert "unknown_embedding:latest" not in conn.supported_models
380
+ assert "unknown_embedding:latest" in conn.untested_models