kiln-ai 0.19.0__py3-none-any.whl → 0.21.0__py3-none-any.whl

This diff compares the contents of publicly available package versions as released to their respective public registries, and is provided for informational purposes only.

Potentially problematic release: this version of kiln-ai might be problematic.

Files changed (158):
  1. kiln_ai/adapters/__init__.py +8 -2
  2. kiln_ai/adapters/adapter_registry.py +43 -208
  3. kiln_ai/adapters/chat/chat_formatter.py +8 -12
  4. kiln_ai/adapters/chat/test_chat_formatter.py +6 -2
  5. kiln_ai/adapters/chunkers/__init__.py +13 -0
  6. kiln_ai/adapters/chunkers/base_chunker.py +42 -0
  7. kiln_ai/adapters/chunkers/chunker_registry.py +16 -0
  8. kiln_ai/adapters/chunkers/fixed_window_chunker.py +39 -0
  9. kiln_ai/adapters/chunkers/helpers.py +23 -0
  10. kiln_ai/adapters/chunkers/test_base_chunker.py +63 -0
  11. kiln_ai/adapters/chunkers/test_chunker_registry.py +28 -0
  12. kiln_ai/adapters/chunkers/test_fixed_window_chunker.py +346 -0
  13. kiln_ai/adapters/chunkers/test_helpers.py +75 -0
  14. kiln_ai/adapters/data_gen/test_data_gen_task.py +9 -3
  15. kiln_ai/adapters/docker_model_runner_tools.py +119 -0
  16. kiln_ai/adapters/embedding/__init__.py +0 -0
  17. kiln_ai/adapters/embedding/base_embedding_adapter.py +44 -0
  18. kiln_ai/adapters/embedding/embedding_registry.py +32 -0
  19. kiln_ai/adapters/embedding/litellm_embedding_adapter.py +199 -0
  20. kiln_ai/adapters/embedding/test_base_embedding_adapter.py +283 -0
  21. kiln_ai/adapters/embedding/test_embedding_registry.py +166 -0
  22. kiln_ai/adapters/embedding/test_litellm_embedding_adapter.py +1149 -0
  23. kiln_ai/adapters/eval/base_eval.py +2 -2
  24. kiln_ai/adapters/eval/eval_runner.py +9 -3
  25. kiln_ai/adapters/eval/g_eval.py +2 -2
  26. kiln_ai/adapters/eval/test_base_eval.py +2 -4
  27. kiln_ai/adapters/eval/test_g_eval.py +4 -5
  28. kiln_ai/adapters/extractors/__init__.py +18 -0
  29. kiln_ai/adapters/extractors/base_extractor.py +72 -0
  30. kiln_ai/adapters/extractors/encoding.py +20 -0
  31. kiln_ai/adapters/extractors/extractor_registry.py +44 -0
  32. kiln_ai/adapters/extractors/extractor_runner.py +112 -0
  33. kiln_ai/adapters/extractors/litellm_extractor.py +386 -0
  34. kiln_ai/adapters/extractors/test_base_extractor.py +244 -0
  35. kiln_ai/adapters/extractors/test_encoding.py +54 -0
  36. kiln_ai/adapters/extractors/test_extractor_registry.py +181 -0
  37. kiln_ai/adapters/extractors/test_extractor_runner.py +181 -0
  38. kiln_ai/adapters/extractors/test_litellm_extractor.py +1192 -0
  39. kiln_ai/adapters/fine_tune/__init__.py +1 -1
  40. kiln_ai/adapters/fine_tune/openai_finetune.py +14 -4
  41. kiln_ai/adapters/fine_tune/test_dataset_formatter.py +2 -2
  42. kiln_ai/adapters/fine_tune/test_fireworks_finetune.py +2 -6
  43. kiln_ai/adapters/fine_tune/test_openai_finetune.py +108 -111
  44. kiln_ai/adapters/fine_tune/test_together_finetune.py +2 -6
  45. kiln_ai/adapters/ml_embedding_model_list.py +192 -0
  46. kiln_ai/adapters/ml_model_list.py +761 -37
  47. kiln_ai/adapters/model_adapters/base_adapter.py +51 -21
  48. kiln_ai/adapters/model_adapters/litellm_adapter.py +380 -138
  49. kiln_ai/adapters/model_adapters/test_base_adapter.py +193 -17
  50. kiln_ai/adapters/model_adapters/test_litellm_adapter.py +407 -2
  51. kiln_ai/adapters/model_adapters/test_litellm_adapter_tools.py +1103 -0
  52. kiln_ai/adapters/model_adapters/test_saving_adapter_results.py +5 -5
  53. kiln_ai/adapters/model_adapters/test_structured_output.py +113 -5
  54. kiln_ai/adapters/ollama_tools.py +69 -12
  55. kiln_ai/adapters/parsers/__init__.py +1 -1
  56. kiln_ai/adapters/provider_tools.py +205 -47
  57. kiln_ai/adapters/rag/deduplication.py +49 -0
  58. kiln_ai/adapters/rag/progress.py +252 -0
  59. kiln_ai/adapters/rag/rag_runners.py +844 -0
  60. kiln_ai/adapters/rag/test_deduplication.py +195 -0
  61. kiln_ai/adapters/rag/test_progress.py +785 -0
  62. kiln_ai/adapters/rag/test_rag_runners.py +2376 -0
  63. kiln_ai/adapters/remote_config.py +80 -8
  64. kiln_ai/adapters/repair/test_repair_task.py +12 -9
  65. kiln_ai/adapters/run_output.py +3 -0
  66. kiln_ai/adapters/test_adapter_registry.py +657 -85
  67. kiln_ai/adapters/test_docker_model_runner_tools.py +305 -0
  68. kiln_ai/adapters/test_ml_embedding_model_list.py +429 -0
  69. kiln_ai/adapters/test_ml_model_list.py +251 -1
  70. kiln_ai/adapters/test_ollama_tools.py +340 -1
  71. kiln_ai/adapters/test_prompt_adaptors.py +13 -6
  72. kiln_ai/adapters/test_prompt_builders.py +1 -1
  73. kiln_ai/adapters/test_provider_tools.py +254 -8
  74. kiln_ai/adapters/test_remote_config.py +651 -58
  75. kiln_ai/adapters/vector_store/__init__.py +1 -0
  76. kiln_ai/adapters/vector_store/base_vector_store_adapter.py +83 -0
  77. kiln_ai/adapters/vector_store/lancedb_adapter.py +389 -0
  78. kiln_ai/adapters/vector_store/test_base_vector_store.py +160 -0
  79. kiln_ai/adapters/vector_store/test_lancedb_adapter.py +1841 -0
  80. kiln_ai/adapters/vector_store/test_vector_store_registry.py +199 -0
  81. kiln_ai/adapters/vector_store/vector_store_registry.py +33 -0
  82. kiln_ai/datamodel/__init__.py +39 -34
  83. kiln_ai/datamodel/basemodel.py +170 -1
  84. kiln_ai/datamodel/chunk.py +158 -0
  85. kiln_ai/datamodel/datamodel_enums.py +28 -0
  86. kiln_ai/datamodel/embedding.py +64 -0
  87. kiln_ai/datamodel/eval.py +1 -1
  88. kiln_ai/datamodel/external_tool_server.py +298 -0
  89. kiln_ai/datamodel/extraction.py +303 -0
  90. kiln_ai/datamodel/json_schema.py +25 -10
  91. kiln_ai/datamodel/project.py +40 -1
  92. kiln_ai/datamodel/rag.py +79 -0
  93. kiln_ai/datamodel/registry.py +0 -15
  94. kiln_ai/datamodel/run_config.py +62 -0
  95. kiln_ai/datamodel/task.py +2 -77
  96. kiln_ai/datamodel/task_output.py +6 -1
  97. kiln_ai/datamodel/task_run.py +41 -0
  98. kiln_ai/datamodel/test_attachment.py +649 -0
  99. kiln_ai/datamodel/test_basemodel.py +4 -4
  100. kiln_ai/datamodel/test_chunk_models.py +317 -0
  101. kiln_ai/datamodel/test_dataset_split.py +1 -1
  102. kiln_ai/datamodel/test_embedding_models.py +448 -0
  103. kiln_ai/datamodel/test_eval_model.py +6 -6
  104. kiln_ai/datamodel/test_example_models.py +175 -0
  105. kiln_ai/datamodel/test_external_tool_server.py +691 -0
  106. kiln_ai/datamodel/test_extraction_chunk.py +206 -0
  107. kiln_ai/datamodel/test_extraction_model.py +470 -0
  108. kiln_ai/datamodel/test_rag.py +641 -0
  109. kiln_ai/datamodel/test_registry.py +8 -3
  110. kiln_ai/datamodel/test_task.py +15 -47
  111. kiln_ai/datamodel/test_tool_id.py +320 -0
  112. kiln_ai/datamodel/test_vector_store.py +320 -0
  113. kiln_ai/datamodel/tool_id.py +105 -0
  114. kiln_ai/datamodel/vector_store.py +141 -0
  115. kiln_ai/tools/__init__.py +8 -0
  116. kiln_ai/tools/base_tool.py +82 -0
  117. kiln_ai/tools/built_in_tools/__init__.py +13 -0
  118. kiln_ai/tools/built_in_tools/math_tools.py +124 -0
  119. kiln_ai/tools/built_in_tools/test_math_tools.py +204 -0
  120. kiln_ai/tools/mcp_server_tool.py +95 -0
  121. kiln_ai/tools/mcp_session_manager.py +246 -0
  122. kiln_ai/tools/rag_tools.py +157 -0
  123. kiln_ai/tools/test_base_tools.py +199 -0
  124. kiln_ai/tools/test_mcp_server_tool.py +457 -0
  125. kiln_ai/tools/test_mcp_session_manager.py +1585 -0
  126. kiln_ai/tools/test_rag_tools.py +848 -0
  127. kiln_ai/tools/test_tool_registry.py +562 -0
  128. kiln_ai/tools/tool_registry.py +85 -0
  129. kiln_ai/utils/__init__.py +3 -0
  130. kiln_ai/utils/async_job_runner.py +62 -17
  131. kiln_ai/utils/config.py +24 -2
  132. kiln_ai/utils/env.py +15 -0
  133. kiln_ai/utils/filesystem.py +14 -0
  134. kiln_ai/utils/filesystem_cache.py +60 -0
  135. kiln_ai/utils/litellm.py +94 -0
  136. kiln_ai/utils/lock.py +100 -0
  137. kiln_ai/utils/mime_type.py +38 -0
  138. kiln_ai/utils/open_ai_types.py +94 -0
  139. kiln_ai/utils/pdf_utils.py +38 -0
  140. kiln_ai/utils/project_utils.py +17 -0
  141. kiln_ai/utils/test_async_job_runner.py +151 -35
  142. kiln_ai/utils/test_config.py +138 -1
  143. kiln_ai/utils/test_env.py +142 -0
  144. kiln_ai/utils/test_filesystem_cache.py +316 -0
  145. kiln_ai/utils/test_litellm.py +206 -0
  146. kiln_ai/utils/test_lock.py +185 -0
  147. kiln_ai/utils/test_mime_type.py +66 -0
  148. kiln_ai/utils/test_open_ai_types.py +131 -0
  149. kiln_ai/utils/test_pdf_utils.py +73 -0
  150. kiln_ai/utils/test_uuid.py +111 -0
  151. kiln_ai/utils/test_validation.py +524 -0
  152. kiln_ai/utils/uuid.py +9 -0
  153. kiln_ai/utils/validation.py +90 -0
  154. {kiln_ai-0.19.0.dist-info → kiln_ai-0.21.0.dist-info}/METADATA +12 -5
  155. kiln_ai-0.21.0.dist-info/RECORD +211 -0
  156. kiln_ai-0.19.0.dist-info/RECORD +0 -115
  157. {kiln_ai-0.19.0.dist-info → kiln_ai-0.21.0.dist-info}/WHEEL +0 -0
  158. {kiln_ai-0.19.0.dist-info → kiln_ai-0.21.0.dist-info}/licenses/LICENSE.txt +0 -0

kiln_ai/adapters/model_adapters/test_saving_adapter_results.py

@@ -13,7 +13,7 @@ from kiln_ai.datamodel import (
     Task,
     Usage,
 )
-from kiln_ai.datamodel.task import RunConfig
+from kiln_ai.datamodel.task import RunConfigProperties
 from kiln_ai.utils.config import Config


@@ -41,8 +41,8 @@ def test_task(tmp_path):
 @pytest.fixture
 def adapter(test_task):
     return MockAdapter(
-        run_config=RunConfig(
-            task=test_task,
+        task=test_task,
+        run_config=RunConfigProperties(
             model_name="phi_3_5",
             model_provider_name="ollama",
             prompt_id="simple_chain_of_thought_prompt_builder",
@@ -240,8 +240,8 @@ async def test_autosave_true(test_task, adapter):
 def test_properties_for_task_output_custom_values(test_task):
     """Test that _properties_for_task_output includes custom temperature, top_p, and structured_output_mode"""
     adapter = MockAdapter(
-        run_config=RunConfig(
-            task=test_task,
+        task=test_task,
+        run_config=RunConfigProperties(
             model_name="gpt-4",
             model_provider_name="openai",
             prompt_id="simple_prompt_builder",

kiln_ai/adapters/model_adapters/test_structured_output.py

@@ -1,8 +1,10 @@
 import json
 from pathlib import Path
 from typing import Dict
+from unittest.mock import Mock, patch

 import pytest
+from litellm.types.utils import ModelResponse

 import kiln_ai.datamodel as datamodel
 from kiln_ai.adapters.adapter_registry import adapter_for_task
@@ -11,7 +13,7 @@ from kiln_ai.adapters.model_adapters.base_adapter import BaseAdapter, RunOutput,
 from kiln_ai.adapters.ollama_tools import ollama_online
 from kiln_ai.adapters.test_prompt_adaptors import get_all_models_and_providers
 from kiln_ai.datamodel import PromptId
-from kiln_ai.datamodel.task import RunConfig, RunConfigProperties
+from kiln_ai.datamodel.task import RunConfigProperties
 from kiln_ai.datamodel.test_json_schema import json_joke_schema, json_triangle_schema


@@ -40,8 +42,8 @@ async def test_structured_output_ollama(tmp_path, model_name):
 class MockAdapter(BaseAdapter):
     def __init__(self, kiln_task: datamodel.Task, response: Dict | str | None):
         super().__init__(
-            run_config=RunConfig(
-                task=kiln_task,
+            task=kiln_task,
+            run_config=RunConfigProperties(
                 model_name="phi_3_5",
                 model_provider_name="ollama",
                 prompt_id="simple_chain_of_thought_prompt_builder",
@@ -259,6 +261,7 @@ async def run_structured_input_task(
     model_name: str,
     provider: str,
     prompt_id: PromptId,
+    verify_trace_cot: bool = False,
 ):
     response, a, run = await run_structured_input_task_no_validation(
         task, model_name, provider, prompt_id
@@ -282,6 +285,32 @@ async def run_structured_input_task(
     assert "reasoning" in run.intermediate_outputs
     assert isinstance(run.intermediate_outputs["reasoning"], str)

+    # Check the trace
+    trace = run.trace
+    assert trace is not None
+    if verify_trace_cot:
+        assert len(trace) == 5
+        assert trace[0]["role"] == "system"
+        assert "You are an assistant which classifies a triangle" in trace[0]["content"]
+        assert trace[1]["role"] == "user"
+        assert trace[2]["role"] == "assistant"
+        assert trace[2].get("tool_calls") is None
+        assert trace[3]["role"] == "user"
+        assert trace[4]["role"] == "assistant"
+        assert trace[4].get("tool_calls") is None
+    else:
+        assert len(trace) == 3
+        assert trace[0]["role"] == "system"
+        assert "You are an assistant which classifies a triangle" in trace[0]["content"]
+        assert trace[1]["role"] == "user"
+        json_content = json.loads(trace[1]["content"])
+        assert json_content["a"] == 2
+        assert json_content["b"] == 2
+        assert json_content["c"] == 2
+        assert trace[2]["role"] == "assistant"
+        assert trace[2].get("tool_calls") is None
+        assert "[[equilateral]]" in trace[2]["content"]
+

 @pytest.mark.paid
 async def test_structured_input_gpt_4o_mini(tmp_path):
@@ -299,15 +328,94 @@ async def test_all_built_in_models_structured_input(
     )


+async def test_all_built_in_models_structured_input_mocked(tmp_path):
+    mock_response = ModelResponse(
+        model="gpt-4o-mini",
+        choices=[
+            {
+                "message": {
+                    "content": "The answer is [[equilateral]]",
+                }
+            }
+        ],
+    )
+
+    # Mock the Config.shared() method to return a mock config with required attributes
+    mock_config = Mock()
+    mock_config.open_ai_api_key = "mock_api_key"
+    mock_config.user_id = "test_user"
+    mock_config.groq_api_key = "mock_api_key"
+
+    with (
+        patch(
+            "litellm.acompletion",
+            side_effect=[mock_response],
+        ),
+        patch("kiln_ai.utils.config.Config.shared", return_value=mock_config),
+    ):
+        await run_structured_input_test(
+            tmp_path, "llama_3_1_8b", "groq", "simple_prompt_builder"
+        )
+
+
 @pytest.mark.paid
 @pytest.mark.ollama
 @pytest.mark.parametrize("model_name,provider_name", get_all_models_and_providers())
 async def test_structured_input_cot_prompt_builder(tmp_path, model_name, provider_name):
     task = build_structured_input_test_task(tmp_path)
     await run_structured_input_task(
-        task, model_name, provider_name, "simple_chain_of_thought_prompt_builder"
+        task,
+        model_name,
+        provider_name,
+        "simple_chain_of_thought_prompt_builder",
+        verify_trace_cot=True,
+    )
+
+
+async def test_structured_input_cot_prompt_builder_mocked(tmp_path):
+    task = build_structured_input_test_task(tmp_path)
+    mock_response_1 = ModelResponse(
+        model="gpt-4o-mini",
+        choices=[
+            {
+                "message": {
+                    "content": "I'm thinking real hard... oh!",
+                }
+            }
+        ],
+    )
+    mock_response_2 = ModelResponse(
+        model="gpt-4o-mini",
+        choices=[
+            {
+                "message": {
+                    "content": "After thinking, I've decided the answer is [[equilateral]]",
+                }
+            }
+        ],
     )

+    # Mock the Config.shared() method to return a mock config with required attributes
+    mock_config = Mock()
+    mock_config.open_ai_api_key = "mock_api_key"
+    mock_config.user_id = "test_user"
+    mock_config.groq_api_key = "mock_api_key"
+
+    with (
+        patch(
+            "litellm.acompletion",
+            side_effect=[mock_response_1, mock_response_2],
+        ),
+        patch("kiln_ai.utils.config.Config.shared", return_value=mock_config),
+    ):
+        await run_structured_input_task(
+            task,
+            "llama_3_1_8b",
+            "groq",
+            "simple_chain_of_thought_prompt_builder",
+            verify_trace_cot=True,
+        )
+

 @pytest.mark.paid
 @pytest.mark.ollama
@@ -350,7 +458,7 @@ When asked for a final result, this is the format (for an equilateral example):
 """
     task.output_json_schema = json.dumps(triangle_schema)
     task.save_to_file()
-    response, adapter, _ = await run_structured_input_task_no_validation(
+    response, _, _ = await run_structured_input_task_no_validation(
         task, model_name, provider_name, "simple_chain_of_thought_prompt_builder"
     )


kiln_ai/adapters/ollama_tools.py

@@ -4,6 +4,7 @@ import httpx
 import requests
 from pydantic import BaseModel, Field

+from kiln_ai.adapters.ml_embedding_model_list import built_in_embedding_models
 from kiln_ai.adapters.ml_model_list import ModelProviderName, built_in_models
 from kiln_ai.utils.config import Config

@@ -41,22 +42,28 @@ class OllamaConnection(BaseModel):
     version: str | None = None
     supported_models: List[str]
     untested_models: List[str] = Field(default_factory=list)
+    supported_embedding_models: List[str] = Field(default_factory=list)

     def all_models(self) -> List[str]:
         return self.supported_models + self.untested_models

+    def all_embedding_models(self) -> List[str]:
+        return self.supported_embedding_models
+

 # Parse the Ollama /api/tags response
-def parse_ollama_tags(tags: Any) -> OllamaConnection | None:
+def parse_ollama_tags(tags: Any) -> OllamaConnection:
     # Build a list of models we support for Ollama from the built-in model list
-    supported_ollama_models = [
-        provider.model_id
-        for model in built_in_models
-        for provider in model.providers
-        if provider.name == ModelProviderName.ollama
-    ]
+    supported_ollama_models = set(
+        [
+            provider.model_id
+            for model in built_in_models
+            for provider in model.providers
+            if provider.name == ModelProviderName.ollama
+        ]
+    )
     # Append model_aliases to supported_ollama_models
-    supported_ollama_models.extend(
+    supported_ollama_models.update(
         [
             alias
             for model in built_in_models
@@ -65,16 +72,44 @@ def parse_ollama_tags(tags: Any) -> OllamaConnection | None:
         ]
     )

+    supported_ollama_embedding_models = set(
+        [
+            provider.model_id
+            for model in built_in_embedding_models
+            for provider in model.providers
+            if provider.name == ModelProviderName.ollama
+        ]
+    )
+    supported_ollama_embedding_models.update(
+        [
+            alias
+            for model in built_in_embedding_models
+            for provider in model.providers
+            for alias in provider.ollama_model_aliases or []
+        ]
+    )
+
     if "models" in tags:
         models = tags["models"]
         if isinstance(models, list):
             model_names = [model["model"] for model in models]
             available_supported_models = []
             untested_models = []
-            supported_models_latest_aliases = [
-                f"{m}:latest" for m in supported_ollama_models
-            ]
+            supported_models_latest_aliases = set(
+                [f"{m}:latest" for m in supported_ollama_models]
+            )
+            supported_embedding_models_latest_aliases = set(
+                [f"{m}:latest" for m in supported_ollama_embedding_models]
+            )
+
             for model in model_names:
+                # Skip embedding models - they should only appear in supported_embedding_models
+                if (
+                    model in supported_ollama_embedding_models
+                    or model in supported_embedding_models_latest_aliases
+                ):
+                    continue
+
                 if (
                     model in supported_ollama_models
                     or model in supported_models_latest_aliases
@@ -83,17 +118,31 @@ def parse_ollama_tags(tags: Any) -> OllamaConnection | None:
                 else:
                     untested_models.append(model)

-            if available_supported_models or untested_models:
+            available_supported_embedding_models = []
+            for model in model_names:
+                if (
+                    model in supported_ollama_embedding_models
+                    or model in supported_embedding_models_latest_aliases
+                ):
+                    available_supported_embedding_models.append(model)
+
+            if (
+                available_supported_models
+                or untested_models
+                or available_supported_embedding_models
+            ):
                 return OllamaConnection(
                     message="Ollama connected",
                     supported_models=available_supported_models,
                     untested_models=untested_models,
+                    supported_embedding_models=available_supported_embedding_models,
                 )

     return OllamaConnection(
         message="Ollama is running, but no supported models are installed. Install one or more supported model, like 'ollama pull phi3.5'.",
         supported_models=[],
         untested_models=[],
+        supported_embedding_models=[],
     )


@@ -113,3 +162,11 @@ async def get_ollama_connection() -> OllamaConnection | None:
 def ollama_model_installed(conn: OllamaConnection, model_name: str) -> bool:
     all_models = conn.all_models()
     return model_name in all_models or f"{model_name}:latest" in all_models
+
+
+def ollama_embedding_model_installed(conn: OllamaConnection, model_name: str) -> bool:
+    all_embedding_models = conn.all_embedding_models()
+    return (
+        model_name in all_embedding_models
+        or f"{model_name}:latest" in all_embedding_models
+    )
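
Taken together, the ollama_tools.py hunks add embedding-model discovery: parse_ollama_tags now always returns an OllamaConnection, installed embedding models are routed into the new supported_embedding_models field instead of the chat-model lists, and ollama_embedding_model_installed mirrors ollama_model_installed. A small sketch of the new surface, using a made-up /api/tags payload (which list each model lands in depends on the built-in model lists):

    from kiln_ai.adapters.ollama_tools import (
        ollama_embedding_model_installed,
        parse_ollama_tags,
    )

    # Hypothetical payload in the shape of Ollama's /api/tags response.
    tags = {
        "models": [
            {"model": "phi3.5:latest"},
            {"model": "nomic-embed-text:latest"},
        ]
    }

    conn = parse_ollama_tags(tags)
    print(conn.supported_models)            # recognized chat models
    print(conn.supported_embedding_models)  # recognized embedding models
    print(conn.untested_models)             # everything else
    print(ollama_embedding_model_installed(conn, "nomic-embed-text"))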

kiln_ai/adapters/parsers/__init__.py

@@ -7,4 +7,4 @@ Parsing utilities for JSON and models with custom output formats (R1, etc.)

 from . import base_parser, json_parser, r1_parser

-__all__ = ["r1_parser", "base_parser", "json_parser"]
+__all__ = ["base_parser", "json_parser", "r1_parser"]

kiln_ai/adapters/provider_tools.py

@@ -1,7 +1,13 @@
 import logging
+import os
 from dataclasses import dataclass
 from typing import Dict, List

+from pydantic import BaseModel
+
+from kiln_ai.adapters.docker_model_runner_tools import (
+    get_docker_model_runner_connection,
+)
 from kiln_ai.adapters.ml_model_list import (
     KilnModel,
     KilnModelProvider,
@@ -10,14 +16,12 @@ from kiln_ai.adapters.ml_model_list import (
     StructuredOutputMode,
     built_in_models,
 )
-from kiln_ai.adapters.model_adapters.litellm_config import LiteLlmConfig
 from kiln_ai.adapters.ollama_tools import get_ollama_connection
 from kiln_ai.datamodel import Finetune, Task
 from kiln_ai.datamodel.datamodel_enums import ChatStrategy
-from kiln_ai.datamodel.registry import project_from_id
-from kiln_ai.datamodel.task import RunConfigProperties
 from kiln_ai.utils.config import Config
 from kiln_ai.utils.exhaustive_error import raise_exhaustive_enum_error
+from kiln_ai.utils.project_utils import project_from_id

 logger = logging.getLogger(__name__)

@@ -32,6 +36,15 @@ async def provider_enabled(provider_name: ModelProviderName) -> bool:
         except Exception:
             return False

+    if provider_name == ModelProviderName.docker_model_runner:
+        try:
+            conn = await get_docker_model_runner_connection()
+            return conn is not None and (
+                len(conn.supported_models) > 0 or len(conn.untested_models) > 0
+            )
+        except Exception:
+            return False
+
     provider_warning = provider_warnings.get(provider_name)
     if provider_warning is None:
         return False
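
This hunk makes provider_enabled treat Docker Model Runner like Ollama: the provider counts as enabled only if a connection succeeds and it reports at least one supported or untested model. A usage sketch, assuming Docker Model Runner is reachable at its default local endpoint:

    import asyncio

    from kiln_ai.adapters.ml_model_list import ModelProviderName
    from kiln_ai.adapters.provider_tools import provider_enabled

    # False when the connection fails or no models are installed.
    enabled = asyncio.run(provider_enabled(ModelProviderName.docker_model_runner))
    print(enabled)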
@@ -180,50 +193,6 @@ def kiln_model_provider_from(
     )


-def lite_llm_config_for_openai_compatible(
-    run_config_properties: RunConfigProperties,
-) -> LiteLlmConfig:
-    model_id = run_config_properties.model_name
-    try:
-        openai_provider_name, model_id = model_id.split("::")
-    except Exception:
-        raise ValueError(f"Invalid openai compatible model ID: {model_id}")
-
-    openai_compatible_providers = Config.shared().openai_compatible_providers or []
-    provider = next(
-        filter(
-            lambda p: p.get("name") == openai_provider_name, openai_compatible_providers
-        ),
-        None,
-    )
-    if provider is None:
-        raise ValueError(f"OpenAI compatible provider {openai_provider_name} not found")
-
-    # API key optional - some providers like Ollama don't use it, but LiteLLM errors without one
-    api_key = provider.get("api_key") or "NA"
-    base_url = provider.get("base_url")
-    if base_url is None:
-        raise ValueError(
-            f"OpenAI compatible provider {openai_provider_name} has no base URL"
-        )
-
-    # Update a copy of the run config properties to use the openai compatible provider
-    updated_run_config_properties = run_config_properties.model_copy(deep=True)
-    updated_run_config_properties.model_provider_name = (
-        ModelProviderName.openai_compatible
-    )
-    updated_run_config_properties.model_name = model_id
-
-    return LiteLlmConfig(
-        # OpenAI compatible, with a custom base URL
-        run_config_properties=updated_run_config_properties,
-        base_url=base_url,
-        additional_body_options={
-            "api_key": api_key,
-        },
-    )
-
-
 def lite_llm_provider_model(
     model_id: str,
 ) -> KilnModelProvider:
@@ -377,6 +346,8 @@ def provider_name_from_id(id: str) -> str:
             return "SiliconFlow"
         case ModelProviderName.cerebras:
             return "Cerebras"
+        case ModelProviderName.docker_model_runner:
+            return "Docker Model Runner"
         case _:
             # triggers pyright warning if I miss a case
             raise_exhaustive_enum_error(enum_id)
@@ -444,3 +415,190 @@ provider_warnings: Dict[ModelProviderName, ModelProviderWarning] = {
         message="Attempted to use Cerebras without an API key set. \nGet your API key from https://cloud.cerebras.ai/platform",
     ),
 }
+
+
+class LiteLlmCoreConfig(BaseModel):
+    base_url: str | None = None
+    default_headers: Dict[str, str] | None = None
+    additional_body_options: Dict[str, str] | None = None
+
+
+def lite_llm_core_config_for_provider(
+    provider_name: ModelProviderName,
+    openai_compatible_provider_name: str | None = None,
+) -> LiteLlmCoreConfig | None:
+    """
+    Returns a LiteLLM core config for a given provider.
+
+    Args:
+        provider_name: The provider to get the config for
+        openai_compatible_provider_name: Required for openai compatible providers, this is the name of the underlying provider
+    """
+    match provider_name:
+        case ModelProviderName.openrouter:
+            return LiteLlmCoreConfig(
+                base_url=(
+                    os.getenv("OPENROUTER_BASE_URL") or "https://openrouter.ai/api/v1"
+                ),
+                default_headers={
+                    "HTTP-Referer": "https://kiln.tech/openrouter",
+                    "X-Title": "KilnAI",
+                },
+                additional_body_options={
+                    "api_key": Config.shared().open_router_api_key,
+                },
+            )
+        case ModelProviderName.siliconflow_cn:
+            return LiteLlmCoreConfig(
+                base_url=os.getenv("SILICONFLOW_BASE_URL")
+                or "https://api.siliconflow.cn/v1",
+                default_headers={
+                    "HTTP-Referer": "https://kiln.tech/siliconflow",
+                    "X-Title": "KilnAI",
+                },
+                additional_body_options={
+                    "api_key": Config.shared().siliconflow_cn_api_key,
+                },
+            )
+        case ModelProviderName.openai:
+            return LiteLlmCoreConfig(
+                additional_body_options={
+                    "api_key": Config.shared().open_ai_api_key,
+                },
+            )
+        case ModelProviderName.groq:
+            return LiteLlmCoreConfig(
+                additional_body_options={
+                    "api_key": Config.shared().groq_api_key,
+                },
+            )
+        case ModelProviderName.amazon_bedrock:
+            return LiteLlmCoreConfig(
+                additional_body_options={
+                    "aws_access_key_id": Config.shared().bedrock_access_key,
+                    "aws_secret_access_key": Config.shared().bedrock_secret_key,
+                    # The only region that's widely supported for bedrock
+                    "aws_region_name": "us-west-2",
+                },
+            )
+        case ModelProviderName.ollama:
+            # Set the Ollama base URL for 2 reasons:
+            # 1. To use the correct base URL
+            # 2. We use Ollama's OpenAI compatible API (/v1), and don't just let litellm use the Ollama API. We use more advanced features like json_schema.
+            ollama_base_url = (
+                Config.shared().ollama_base_url or "http://localhost:11434"
+            )
+            return LiteLlmCoreConfig(
+                base_url=ollama_base_url + "/v1",
+                additional_body_options={
+                    # LiteLLM errors without an api_key, even though Ollama doesn't support one
+                    "api_key": "NA",
+                },
+            )
+        case ModelProviderName.docker_model_runner:
+            docker_base_url = (
+                Config.shared().docker_model_runner_base_url
+                or "http://localhost:12434/engines/llama.cpp"
+            )
+            return LiteLlmCoreConfig(
+                # Docker Model Runner uses OpenAI-compatible API at /v1 endpoint
+                base_url=docker_base_url + "/v1",
+                additional_body_options={
+                    # LiteLLM errors without an api_key, even though Docker Model Runner doesn't require one.
+                    "api_key": "DMR",
+                },
+            )
+        case ModelProviderName.fireworks_ai:
+            return LiteLlmCoreConfig(
+                additional_body_options={
+                    "api_key": Config.shared().fireworks_api_key,
+                },
+            )
+        case ModelProviderName.anthropic:
+            return LiteLlmCoreConfig(
+                additional_body_options={
+                    "api_key": Config.shared().anthropic_api_key,
+                },
+            )
+        case ModelProviderName.gemini_api:
+            return LiteLlmCoreConfig(
+                additional_body_options={
+                    "api_key": Config.shared().gemini_api_key,
+                },
+            )
+        case ModelProviderName.vertex:
+            return LiteLlmCoreConfig(
+                additional_body_options={
+                    "vertex_project": Config.shared().vertex_project_id,
+                    "vertex_location": Config.shared().vertex_location,
+                },
+            )
+        case ModelProviderName.together_ai:
+            return LiteLlmCoreConfig(
+                additional_body_options={
+                    "api_key": Config.shared().together_api_key,
+                },
+            )
+        case ModelProviderName.azure_openai:
+            return LiteLlmCoreConfig(
+                base_url=Config.shared().azure_openai_endpoint,
+                additional_body_options={
+                    "api_key": Config.shared().azure_openai_api_key,
+                    "api_version": "2025-02-01-preview",
+                },
+            )
+        case ModelProviderName.huggingface:
+            return LiteLlmCoreConfig(
+                additional_body_options={
+                    "api_key": Config.shared().huggingface_api_key,
+                },
+            )
+        case ModelProviderName.cerebras:
+            return LiteLlmCoreConfig(
+                additional_body_options={
+                    "api_key": Config.shared().cerebras_api_key,
+                },
+            )
+        case ModelProviderName.openai_compatible:
+            # openai compatible requires a model name in the format "provider::model_name"
+            if openai_compatible_provider_name is None:
+                raise ValueError("OpenAI compatible provider requires a provider name")
+
+            openai_compatible_providers = (
+                Config.shared().openai_compatible_providers or []
+            )
+
+            provider = next(
+                filter(
+                    lambda p: p.get("name") == openai_compatible_provider_name,
+                    openai_compatible_providers,
+                ),
+                None,
+            )
+
+            if provider is None:
+                raise ValueError(
+                    f"OpenAI compatible provider {openai_compatible_provider_name} not found"
+                )
+
+            # API key optional - some providers like Ollama don't use it, but LiteLLM errors without one
+            api_key = provider.get("api_key") or "NA"
+            base_url = provider.get("base_url")
+            if base_url is None:
+                raise ValueError(
+                    f"OpenAI compatible provider {openai_compatible_provider_name} has no base URL"
+                )
+
+            return LiteLlmCoreConfig(
+                base_url=base_url,
+                additional_body_options={
+                    "api_key": api_key,
+                },
+            )
+        # These are virtual providers that should have mapped to an actual provider upstream (using core_provider method)
+        case ModelProviderName.kiln_fine_tune:
+            return None
+        case ModelProviderName.kiln_custom_registry:
+            return None
+        case _:
+            raise_exhaustive_enum_error(provider_name)
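
The new lite_llm_core_config_for_provider consolidates per-provider LiteLLM settings (base URL, default headers, API-key body options) into a single lookup, absorbing the removed lite_llm_config_for_openai_compatible. A sketch of the two call shapes, based only on the signature above ("my-provider" is a hypothetical entry in Config.shared().openai_compatible_providers):

    from kiln_ai.adapters.ml_model_list import ModelProviderName
    from kiln_ai.adapters.provider_tools import lite_llm_core_config_for_provider

    # Concrete providers return a LiteLlmCoreConfig; the virtual providers
    # (kiln_fine_tune, kiln_custom_registry) return None.
    cfg = lite_llm_core_config_for_provider(ModelProviderName.ollama)
    assert cfg is not None
    print(cfg.base_url)  # Ollama's OpenAI-compatible endpoint, ending in /v1

    # OpenAI-compatible providers also require the configured provider's name.
    cfg = lite_llm_core_config_for_provider(
        ModelProviderName.openai_compatible,
        openai_compatible_provider_name="my-provider",
    )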

kiln_ai/adapters/rag/deduplication.py (new file)

@@ -0,0 +1,49 @@
+from collections import defaultdict
+from typing import DefaultDict
+
+from kiln_ai.datamodel.chunk import ChunkedDocument
+from kiln_ai.datamodel.embedding import ChunkEmbeddings
+from kiln_ai.datamodel.extraction import Document, Extraction
+
+
+def deduplicate_extractions(items: list[Extraction]) -> list[Extraction]:
+    grouped_items: DefaultDict[str, list[Extraction]] = defaultdict(list)
+    for item in items:
+        if item.extractor_config_id is None:
+            raise ValueError("Extractor config ID is required")
+        grouped_items[item.extractor_config_id].append(item)
+    return [min(group, key=lambda x: x.created_at) for group in grouped_items.values()]
+
+
+def deduplicate_chunked_documents(
+    items: list[ChunkedDocument],
+) -> list[ChunkedDocument]:
+    grouped_items: DefaultDict[str, list[ChunkedDocument]] = defaultdict(list)
+    for item in items:
+        if item.chunker_config_id is None:
+            raise ValueError("Chunker config ID is required")
+        grouped_items[item.chunker_config_id].append(item)
+    return [min(group, key=lambda x: x.created_at) for group in grouped_items.values()]
+
+
+def deduplicate_chunk_embeddings(items: list[ChunkEmbeddings]) -> list[ChunkEmbeddings]:
+    grouped_items: DefaultDict[str, list[ChunkEmbeddings]] = defaultdict(list)
+    for item in items:
+        if item.embedding_config_id is None:
+            raise ValueError("Embedding config ID is required")
+        grouped_items[item.embedding_config_id].append(item)
+    return [min(group, key=lambda x: x.created_at) for group in grouped_items.values()]
+
+
+def filter_documents_by_tags(
+    documents: list[Document], tags: list[str] | None
+) -> list[Document]:
+    if not tags:
+        return documents
+
+    filtered_documents = []
+    for document in documents:
+        if document.tags and any(tag in document.tags for tag in tags):
+            filtered_documents.append(document)
+
+    return filtered_documents
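
The deduplication helpers all follow the same rule: group items by their config ID and keep the earliest (min created_at) per group, raising ValueError when a config ID is missing; filter_documents_by_tags keeps documents matching any requested tag and passes everything through when tags is None or empty. A small usage sketch, assuming extractions and documents are lists loaded elsewhere:

    from kiln_ai.adapters.rag.deduplication import (
        deduplicate_extractions,
        filter_documents_by_tags,
    )

    # One Extraction per extractor_config_id, keeping the oldest.
    unique_extractions = deduplicate_extractions(extractions)

    # tags=None or [] returns the input unchanged; otherwise any-tag match wins.
    onboarding_docs = filter_documents_by_tags(documents, tags=["onboarding"])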