kiln-ai 0.19.0__py3-none-any.whl → 0.21.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of kiln-ai might be problematic.
- kiln_ai/adapters/__init__.py +8 -2
- kiln_ai/adapters/adapter_registry.py +43 -208
- kiln_ai/adapters/chat/chat_formatter.py +8 -12
- kiln_ai/adapters/chat/test_chat_formatter.py +6 -2
- kiln_ai/adapters/chunkers/__init__.py +13 -0
- kiln_ai/adapters/chunkers/base_chunker.py +42 -0
- kiln_ai/adapters/chunkers/chunker_registry.py +16 -0
- kiln_ai/adapters/chunkers/fixed_window_chunker.py +39 -0
- kiln_ai/adapters/chunkers/helpers.py +23 -0
- kiln_ai/adapters/chunkers/test_base_chunker.py +63 -0
- kiln_ai/adapters/chunkers/test_chunker_registry.py +28 -0
- kiln_ai/adapters/chunkers/test_fixed_window_chunker.py +346 -0
- kiln_ai/adapters/chunkers/test_helpers.py +75 -0
- kiln_ai/adapters/data_gen/test_data_gen_task.py +9 -3
- kiln_ai/adapters/docker_model_runner_tools.py +119 -0
- kiln_ai/adapters/embedding/__init__.py +0 -0
- kiln_ai/adapters/embedding/base_embedding_adapter.py +44 -0
- kiln_ai/adapters/embedding/embedding_registry.py +32 -0
- kiln_ai/adapters/embedding/litellm_embedding_adapter.py +199 -0
- kiln_ai/adapters/embedding/test_base_embedding_adapter.py +283 -0
- kiln_ai/adapters/embedding/test_embedding_registry.py +166 -0
- kiln_ai/adapters/embedding/test_litellm_embedding_adapter.py +1149 -0
- kiln_ai/adapters/eval/base_eval.py +2 -2
- kiln_ai/adapters/eval/eval_runner.py +9 -3
- kiln_ai/adapters/eval/g_eval.py +2 -2
- kiln_ai/adapters/eval/test_base_eval.py +2 -4
- kiln_ai/adapters/eval/test_g_eval.py +4 -5
- kiln_ai/adapters/extractors/__init__.py +18 -0
- kiln_ai/adapters/extractors/base_extractor.py +72 -0
- kiln_ai/adapters/extractors/encoding.py +20 -0
- kiln_ai/adapters/extractors/extractor_registry.py +44 -0
- kiln_ai/adapters/extractors/extractor_runner.py +112 -0
- kiln_ai/adapters/extractors/litellm_extractor.py +386 -0
- kiln_ai/adapters/extractors/test_base_extractor.py +244 -0
- kiln_ai/adapters/extractors/test_encoding.py +54 -0
- kiln_ai/adapters/extractors/test_extractor_registry.py +181 -0
- kiln_ai/adapters/extractors/test_extractor_runner.py +181 -0
- kiln_ai/adapters/extractors/test_litellm_extractor.py +1192 -0
- kiln_ai/adapters/fine_tune/__init__.py +1 -1
- kiln_ai/adapters/fine_tune/openai_finetune.py +14 -4
- kiln_ai/adapters/fine_tune/test_dataset_formatter.py +2 -2
- kiln_ai/adapters/fine_tune/test_fireworks_finetune.py +2 -6
- kiln_ai/adapters/fine_tune/test_openai_finetune.py +108 -111
- kiln_ai/adapters/fine_tune/test_together_finetune.py +2 -6
- kiln_ai/adapters/ml_embedding_model_list.py +192 -0
- kiln_ai/adapters/ml_model_list.py +761 -37
- kiln_ai/adapters/model_adapters/base_adapter.py +51 -21
- kiln_ai/adapters/model_adapters/litellm_adapter.py +380 -138
- kiln_ai/adapters/model_adapters/test_base_adapter.py +193 -17
- kiln_ai/adapters/model_adapters/test_litellm_adapter.py +407 -2
- kiln_ai/adapters/model_adapters/test_litellm_adapter_tools.py +1103 -0
- kiln_ai/adapters/model_adapters/test_saving_adapter_results.py +5 -5
- kiln_ai/adapters/model_adapters/test_structured_output.py +113 -5
- kiln_ai/adapters/ollama_tools.py +69 -12
- kiln_ai/adapters/parsers/__init__.py +1 -1
- kiln_ai/adapters/provider_tools.py +205 -47
- kiln_ai/adapters/rag/deduplication.py +49 -0
- kiln_ai/adapters/rag/progress.py +252 -0
- kiln_ai/adapters/rag/rag_runners.py +844 -0
- kiln_ai/adapters/rag/test_deduplication.py +195 -0
- kiln_ai/adapters/rag/test_progress.py +785 -0
- kiln_ai/adapters/rag/test_rag_runners.py +2376 -0
- kiln_ai/adapters/remote_config.py +80 -8
- kiln_ai/adapters/repair/test_repair_task.py +12 -9
- kiln_ai/adapters/run_output.py +3 -0
- kiln_ai/adapters/test_adapter_registry.py +657 -85
- kiln_ai/adapters/test_docker_model_runner_tools.py +305 -0
- kiln_ai/adapters/test_ml_embedding_model_list.py +429 -0
- kiln_ai/adapters/test_ml_model_list.py +251 -1
- kiln_ai/adapters/test_ollama_tools.py +340 -1
- kiln_ai/adapters/test_prompt_adaptors.py +13 -6
- kiln_ai/adapters/test_prompt_builders.py +1 -1
- kiln_ai/adapters/test_provider_tools.py +254 -8
- kiln_ai/adapters/test_remote_config.py +651 -58
- kiln_ai/adapters/vector_store/__init__.py +1 -0
- kiln_ai/adapters/vector_store/base_vector_store_adapter.py +83 -0
- kiln_ai/adapters/vector_store/lancedb_adapter.py +389 -0
- kiln_ai/adapters/vector_store/test_base_vector_store.py +160 -0
- kiln_ai/adapters/vector_store/test_lancedb_adapter.py +1841 -0
- kiln_ai/adapters/vector_store/test_vector_store_registry.py +199 -0
- kiln_ai/adapters/vector_store/vector_store_registry.py +33 -0
- kiln_ai/datamodel/__init__.py +39 -34
- kiln_ai/datamodel/basemodel.py +170 -1
- kiln_ai/datamodel/chunk.py +158 -0
- kiln_ai/datamodel/datamodel_enums.py +28 -0
- kiln_ai/datamodel/embedding.py +64 -0
- kiln_ai/datamodel/eval.py +1 -1
- kiln_ai/datamodel/external_tool_server.py +298 -0
- kiln_ai/datamodel/extraction.py +303 -0
- kiln_ai/datamodel/json_schema.py +25 -10
- kiln_ai/datamodel/project.py +40 -1
- kiln_ai/datamodel/rag.py +79 -0
- kiln_ai/datamodel/registry.py +0 -15
- kiln_ai/datamodel/run_config.py +62 -0
- kiln_ai/datamodel/task.py +2 -77
- kiln_ai/datamodel/task_output.py +6 -1
- kiln_ai/datamodel/task_run.py +41 -0
- kiln_ai/datamodel/test_attachment.py +649 -0
- kiln_ai/datamodel/test_basemodel.py +4 -4
- kiln_ai/datamodel/test_chunk_models.py +317 -0
- kiln_ai/datamodel/test_dataset_split.py +1 -1
- kiln_ai/datamodel/test_embedding_models.py +448 -0
- kiln_ai/datamodel/test_eval_model.py +6 -6
- kiln_ai/datamodel/test_example_models.py +175 -0
- kiln_ai/datamodel/test_external_tool_server.py +691 -0
- kiln_ai/datamodel/test_extraction_chunk.py +206 -0
- kiln_ai/datamodel/test_extraction_model.py +470 -0
- kiln_ai/datamodel/test_rag.py +641 -0
- kiln_ai/datamodel/test_registry.py +8 -3
- kiln_ai/datamodel/test_task.py +15 -47
- kiln_ai/datamodel/test_tool_id.py +320 -0
- kiln_ai/datamodel/test_vector_store.py +320 -0
- kiln_ai/datamodel/tool_id.py +105 -0
- kiln_ai/datamodel/vector_store.py +141 -0
- kiln_ai/tools/__init__.py +8 -0
- kiln_ai/tools/base_tool.py +82 -0
- kiln_ai/tools/built_in_tools/__init__.py +13 -0
- kiln_ai/tools/built_in_tools/math_tools.py +124 -0
- kiln_ai/tools/built_in_tools/test_math_tools.py +204 -0
- kiln_ai/tools/mcp_server_tool.py +95 -0
- kiln_ai/tools/mcp_session_manager.py +246 -0
- kiln_ai/tools/rag_tools.py +157 -0
- kiln_ai/tools/test_base_tools.py +199 -0
- kiln_ai/tools/test_mcp_server_tool.py +457 -0
- kiln_ai/tools/test_mcp_session_manager.py +1585 -0
- kiln_ai/tools/test_rag_tools.py +848 -0
- kiln_ai/tools/test_tool_registry.py +562 -0
- kiln_ai/tools/tool_registry.py +85 -0
- kiln_ai/utils/__init__.py +3 -0
- kiln_ai/utils/async_job_runner.py +62 -17
- kiln_ai/utils/config.py +24 -2
- kiln_ai/utils/env.py +15 -0
- kiln_ai/utils/filesystem.py +14 -0
- kiln_ai/utils/filesystem_cache.py +60 -0
- kiln_ai/utils/litellm.py +94 -0
- kiln_ai/utils/lock.py +100 -0
- kiln_ai/utils/mime_type.py +38 -0
- kiln_ai/utils/open_ai_types.py +94 -0
- kiln_ai/utils/pdf_utils.py +38 -0
- kiln_ai/utils/project_utils.py +17 -0
- kiln_ai/utils/test_async_job_runner.py +151 -35
- kiln_ai/utils/test_config.py +138 -1
- kiln_ai/utils/test_env.py +142 -0
- kiln_ai/utils/test_filesystem_cache.py +316 -0
- kiln_ai/utils/test_litellm.py +206 -0
- kiln_ai/utils/test_lock.py +185 -0
- kiln_ai/utils/test_mime_type.py +66 -0
- kiln_ai/utils/test_open_ai_types.py +131 -0
- kiln_ai/utils/test_pdf_utils.py +73 -0
- kiln_ai/utils/test_uuid.py +111 -0
- kiln_ai/utils/test_validation.py +524 -0
- kiln_ai/utils/uuid.py +9 -0
- kiln_ai/utils/validation.py +90 -0
- {kiln_ai-0.19.0.dist-info → kiln_ai-0.21.0.dist-info}/METADATA +12 -5
- kiln_ai-0.21.0.dist-info/RECORD +211 -0
- kiln_ai-0.19.0.dist-info/RECORD +0 -115
- {kiln_ai-0.19.0.dist-info → kiln_ai-0.21.0.dist-info}/WHEEL +0 -0
- {kiln_ai-0.19.0.dist-info → kiln_ai-0.21.0.dist-info}/licenses/LICENSE.txt +0 -0
@@ -13,7 +13,7 @@ from kiln_ai.datamodel import (
     Task,
     Usage,
 )
-from kiln_ai.datamodel.task import
+from kiln_ai.datamodel.task import RunConfigProperties
 from kiln_ai.utils.config import Config


@@ -41,8 +41,8 @@ def test_task(tmp_path):
 @pytest.fixture
 def adapter(test_task):
     return MockAdapter(
-
-
+        task=test_task,
+        run_config=RunConfigProperties(
             model_name="phi_3_5",
             model_provider_name="ollama",
             prompt_id="simple_chain_of_thought_prompt_builder",

@@ -240,8 +240,8 @@ async def test_autosave_true(test_task, adapter):
 def test_properties_for_task_output_custom_values(test_task):
     """Test that _properties_for_task_output includes custom temperature, top_p, and structured_output_mode"""
     adapter = MockAdapter(
-
-
+        task=test_task,
+        run_config=RunConfigProperties(
             model_name="gpt-4",
             model_provider_name="openai",
             prompt_id="simple_prompt_builder",
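These hunks (apparently from the model-adapter tests) switch MockAdapter construction from positional arguments to explicit task= and run_config=RunConfigProperties(...) keywords. A plain-dict summary of the new call shape; only the fields visible in the hunks are listed, and the real RunConfigProperties may require more:

# Keyword arguments now passed to the test adapter, per the hunks above.
# The nested dict stands in for RunConfigProperties(...); fields not shown in the diff are omitted.
adapter_kwargs = {
    "task": "<a kiln_ai.datamodel.Task instance>",
    "run_config": {
        "model_name": "phi_3_5",
        "model_provider_name": "ollama",
        "prompt_id": "simple_chain_of_thought_prompt_builder",
    },
}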
@@ -1,8 +1,10 @@
 import json
 from pathlib import Path
 from typing import Dict
+from unittest.mock import Mock, patch

 import pytest
+from litellm.types.utils import ModelResponse

 import kiln_ai.datamodel as datamodel
 from kiln_ai.adapters.adapter_registry import adapter_for_task

@@ -11,7 +13,7 @@ from kiln_ai.adapters.model_adapters.base_adapter import BaseAdapter, RunOutput,
 from kiln_ai.adapters.ollama_tools import ollama_online
 from kiln_ai.adapters.test_prompt_adaptors import get_all_models_and_providers
 from kiln_ai.datamodel import PromptId
-from kiln_ai.datamodel.task import
+from kiln_ai.datamodel.task import RunConfigProperties
 from kiln_ai.datamodel.test_json_schema import json_joke_schema, json_triangle_schema


@@ -40,8 +42,8 @@ async def test_structured_output_ollama(tmp_path, model_name):
 class MockAdapter(BaseAdapter):
     def __init__(self, kiln_task: datamodel.Task, response: Dict | str | None):
         super().__init__(
-
-
+            task=kiln_task,
+            run_config=RunConfigProperties(
                 model_name="phi_3_5",
                 model_provider_name="ollama",
                 prompt_id="simple_chain_of_thought_prompt_builder",

@@ -259,6 +261,7 @@ async def run_structured_input_task(
     model_name: str,
     provider: str,
     prompt_id: PromptId,
+    verify_trace_cot: bool = False,
 ):
     response, a, run = await run_structured_input_task_no_validation(
         task, model_name, provider, prompt_id
@@ -282,6 +285,32 @@ async def run_structured_input_task(
         assert "reasoning" in run.intermediate_outputs
         assert isinstance(run.intermediate_outputs["reasoning"], str)

+    # Check the trace
+    trace = run.trace
+    assert trace is not None
+    if verify_trace_cot:
+        assert len(trace) == 5
+        assert trace[0]["role"] == "system"
+        assert "You are an assistant which classifies a triangle" in trace[0]["content"]
+        assert trace[1]["role"] == "user"
+        assert trace[2]["role"] == "assistant"
+        assert trace[2].get("tool_calls") is None
+        assert trace[3]["role"] == "user"
+        assert trace[4]["role"] == "assistant"
+        assert trace[4].get("tool_calls") is None
+    else:
+        assert len(trace) == 3
+        assert trace[0]["role"] == "system"
+        assert "You are an assistant which classifies a triangle" in trace[0]["content"]
+        assert trace[1]["role"] == "user"
+        json_content = json.loads(trace[1]["content"])
+        assert json_content["a"] == 2
+        assert json_content["b"] == 2
+        assert json_content["c"] == 2
+        assert trace[2]["role"] == "assistant"
+        assert trace[2].get("tool_calls") is None
+        assert "[[equilateral]]" in trace[2]["content"]
+

 @pytest.mark.paid
 async def test_structured_input_gpt_4o_mini(tmp_path):
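The verify_trace_cot branch added above pins down the shape of run.trace for the two flows. As a plain-Python summary of what those assertions check (an illustrative helper, not part of the package):

# Chain-of-thought runs produce two assistant turns (5 messages total);
# plain structured runs produce one (3 messages total).
COT_TRACE_ROLES = ["system", "user", "assistant", "user", "assistant"]
SINGLE_TURN_TRACE_ROLES = ["system", "user", "assistant"]


def trace_roles(trace: list[dict]) -> list[str]:
    # Mirror of the role checks in run_structured_input_task above.
    return [message["role"] for message in trace]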
@@ -299,15 +328,94 @@ async def test_all_built_in_models_structured_input(
     )


+async def test_all_built_in_models_structured_input_mocked(tmp_path):
+    mock_response = ModelResponse(
+        model="gpt-4o-mini",
+        choices=[
+            {
+                "message": {
+                    "content": "The answer is [[equilateral]]",
+                }
+            }
+        ],
+    )
+
+    # Mock the Config.shared() method to return a mock config with required attributes
+    mock_config = Mock()
+    mock_config.open_ai_api_key = "mock_api_key"
+    mock_config.user_id = "test_user"
+    mock_config.groq_api_key = "mock_api_key"
+
+    with (
+        patch(
+            "litellm.acompletion",
+            side_effect=[mock_response],
+        ),
+        patch("kiln_ai.utils.config.Config.shared", return_value=mock_config),
+    ):
+        await run_structured_input_test(
+            tmp_path, "llama_3_1_8b", "groq", "simple_prompt_builder"
+        )
+
+
 @pytest.mark.paid
 @pytest.mark.ollama
 @pytest.mark.parametrize("model_name,provider_name", get_all_models_and_providers())
 async def test_structured_input_cot_prompt_builder(tmp_path, model_name, provider_name):
     task = build_structured_input_test_task(tmp_path)
     await run_structured_input_task(
-        task,
+        task,
+        model_name,
+        provider_name,
+        "simple_chain_of_thought_prompt_builder",
+        verify_trace_cot=True,
+    )
+
+
+async def test_structured_input_cot_prompt_builder_mocked(tmp_path):
+    task = build_structured_input_test_task(tmp_path)
+    mock_response_1 = ModelResponse(
+        model="gpt-4o-mini",
+        choices=[
+            {
+                "message": {
+                    "content": "I'm thinking real hard... oh!",
+                }
+            }
+        ],
+    )
+    mock_response_2 = ModelResponse(
+        model="gpt-4o-mini",
+        choices=[
+            {
+                "message": {
+                    "content": "After thinking, I've decided the answer is [[equilateral]]",
+                }
+            }
+        ],
     )

+    # Mock the Config.shared() method to return a mock config with required attributes
+    mock_config = Mock()
+    mock_config.open_ai_api_key = "mock_api_key"
+    mock_config.user_id = "test_user"
+    mock_config.groq_api_key = "mock_api_key"
+
+    with (
+        patch(
+            "litellm.acompletion",
+            side_effect=[mock_response_1, mock_response_2],
+        ),
+        patch("kiln_ai.utils.config.Config.shared", return_value=mock_config),
+    ):
+        await run_structured_input_task(
+            task,
+            "llama_3_1_8b",
+            "groq",
+            "simple_chain_of_thought_prompt_builder",
+            verify_trace_cot=True,
+        )
+

 @pytest.mark.paid
 @pytest.mark.ollama

@@ -350,7 +458,7 @@ When asked for a final result, this is the format (for an equilateral example):
 """
     task.output_json_schema = json.dumps(triangle_schema)
     task.save_to_file()
-    response,
+    response, _, _ = await run_structured_input_task_no_validation(
         task, model_name, provider_name, "simple_chain_of_thought_prompt_builder"
     )
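Both *_mocked tests above follow the same pattern: build litellm ModelResponse objects, queue them on litellm.acompletion via side_effect (one entry per expected model call, so the chain-of-thought test queues two), and stub Config.shared() so no real API keys are read. A stripped-down sketch of that pattern; the helper name and placeholder contents are mine, not Kiln's:

from unittest.mock import Mock, patch

from litellm.types.utils import ModelResponse


def canned_response(content: str) -> ModelResponse:
    # Same minimal shape the tests above use: a single choice with a message body.
    return ModelResponse(model="gpt-4o-mini", choices=[{"message": {"content": content}}])


mock_config = Mock()
mock_config.open_ai_api_key = "mock_api_key"
mock_config.groq_api_key = "mock_api_key"
mock_config.user_id = "test_user"

with (
    # One queued response per model call the code under test will make.
    patch("litellm.acompletion", side_effect=[canned_response("reasoning..."), canned_response("[[equilateral]]")]),
    patch("kiln_ai.utils.config.Config.shared", return_value=mock_config),
):
    pass  # invoke the adapter or test helper here, exactly as the tests above do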
kiln_ai/adapters/ollama_tools.py  CHANGED

@@ -4,6 +4,7 @@ import httpx
 import requests
 from pydantic import BaseModel, Field

+from kiln_ai.adapters.ml_embedding_model_list import built_in_embedding_models
 from kiln_ai.adapters.ml_model_list import ModelProviderName, built_in_models
 from kiln_ai.utils.config import Config

@@ -41,22 +42,28 @@ class OllamaConnection(BaseModel):
     version: str | None = None
     supported_models: List[str]
     untested_models: List[str] = Field(default_factory=list)
+    supported_embedding_models: List[str] = Field(default_factory=list)

     def all_models(self) -> List[str]:
         return self.supported_models + self.untested_models

+    def all_embedding_models(self) -> List[str]:
+        return self.supported_embedding_models
+

 # Parse the Ollama /api/tags response
-def parse_ollama_tags(tags: Any) -> OllamaConnection
+def parse_ollama_tags(tags: Any) -> OllamaConnection:
     # Build a list of models we support for Ollama from the built-in model list
-    supported_ollama_models =
-
-
-
-
-
+    supported_ollama_models = set(
+        [
+            provider.model_id
+            for model in built_in_models
+            for provider in model.providers
+            if provider.name == ModelProviderName.ollama
+        ]
+    )
     # Append model_aliases to supported_ollama_models
-    supported_ollama_models.
+    supported_ollama_models.update(
         [
             alias
             for model in built_in_models

@@ -65,16 +72,44 @@ def parse_ollama_tags(tags: Any) -> OllamaConnection | None:
         ]
     )

+    supported_ollama_embedding_models = set(
+        [
+            provider.model_id
+            for model in built_in_embedding_models
+            for provider in model.providers
+            if provider.name == ModelProviderName.ollama
+        ]
+    )
+    supported_ollama_embedding_models.update(
+        [
+            alias
+            for model in built_in_embedding_models
+            for provider in model.providers
+            for alias in provider.ollama_model_aliases or []
+        ]
+    )
+
     if "models" in tags:
         models = tags["models"]
         if isinstance(models, list):
             model_names = [model["model"] for model in models]
             available_supported_models = []
             untested_models = []
-            supported_models_latest_aliases =
-                f"{m}:latest" for m in supported_ollama_models
-
+            supported_models_latest_aliases = set(
+                [f"{m}:latest" for m in supported_ollama_models]
+            )
+            supported_embedding_models_latest_aliases = set(
+                [f"{m}:latest" for m in supported_ollama_embedding_models]
+            )
+
             for model in model_names:
+                # Skip embedding models - they should only appear in supported_embedding_models
+                if (
+                    model in supported_ollama_embedding_models
+                    or model in supported_embedding_models_latest_aliases
+                ):
+                    continue
+
                 if (
                     model in supported_ollama_models
                     or model in supported_models_latest_aliases

@@ -83,17 +118,31 @@ def parse_ollama_tags(tags: Any) -> OllamaConnection | None:
                 else:
                     untested_models.append(model)

-
+            available_supported_embedding_models = []
+            for model in model_names:
+                if (
+                    model in supported_ollama_embedding_models
+                    or model in supported_embedding_models_latest_aliases
+                ):
+                    available_supported_embedding_models.append(model)
+
+            if (
+                available_supported_models
+                or untested_models
+                or available_supported_embedding_models
+            ):
                 return OllamaConnection(
                     message="Ollama connected",
                     supported_models=available_supported_models,
                     untested_models=untested_models,
+                    supported_embedding_models=available_supported_embedding_models,
                 )

     return OllamaConnection(
         message="Ollama is running, but no supported models are installed. Install one or more supported model, like 'ollama pull phi3.5'.",
         supported_models=[],
         untested_models=[],
+        supported_embedding_models=[],
     )

@@ -113,3 +162,11 @@ async def get_ollama_connection() -> OllamaConnection | None:
 def ollama_model_installed(conn: OllamaConnection, model_name: str) -> bool:
     all_models = conn.all_models()
     return model_name in all_models or f"{model_name}:latest" in all_models
+
+
+def ollama_embedding_model_installed(conn: OllamaConnection, model_name: str) -> bool:
+    all_embedding_models = conn.all_embedding_models()
+    return (
+        model_name in all_embedding_models
+        or f"{model_name}:latest" in all_embedding_models
+    )
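Taken together, the ollama_tools.py changes teach parse_ollama_tags about embedding models: any installed model that matches Kiln's built_in_embedding_models (or its :latest alias) is routed into supported_embedding_models instead of the chat-model lists. A sketch of how a caller might exercise this, using the /api/tags payload shape the parser reads; whether a given name lands in the "supported" buckets depends on Kiln's built-in model lists, so the model names below are illustrative:

from kiln_ai.adapters.ollama_tools import (
    ollama_embedding_model_installed,
    ollama_model_installed,
    parse_ollama_tags,
)

# Trimmed Ollama GET /api/tags payload: the parser only reads tags["models"][i]["model"].
tags = {"models": [{"model": "phi3.5:latest"}, {"model": "nomic-embed-text:latest"}]}

conn = parse_ollama_tags(tags)
print(conn.supported_models)            # installed chat models Kiln recognizes
print(conn.untested_models)             # installed models Kiln does not recognize
print(conn.supported_embedding_models)  # installed embedding models Kiln recognizes
print(ollama_model_installed(conn, "phi3.5"))
print(ollama_embedding_model_installed(conn, "nomic-embed-text"))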
@@ -1,7 +1,13 @@
 import logging
+import os
 from dataclasses import dataclass
 from typing import Dict, List

+from pydantic import BaseModel
+
+from kiln_ai.adapters.docker_model_runner_tools import (
+    get_docker_model_runner_connection,
+)
 from kiln_ai.adapters.ml_model_list import (
     KilnModel,
     KilnModelProvider,

@@ -10,14 +16,12 @@ from kiln_ai.adapters.ml_model_list import (
     StructuredOutputMode,
     built_in_models,
 )
-from kiln_ai.adapters.model_adapters.litellm_config import LiteLlmConfig
 from kiln_ai.adapters.ollama_tools import get_ollama_connection
 from kiln_ai.datamodel import Finetune, Task
 from kiln_ai.datamodel.datamodel_enums import ChatStrategy
-from kiln_ai.datamodel.registry import project_from_id
-from kiln_ai.datamodel.task import RunConfigProperties
 from kiln_ai.utils.config import Config
 from kiln_ai.utils.exhaustive_error import raise_exhaustive_enum_error
+from kiln_ai.utils.project_utils import project_from_id

 logger = logging.getLogger(__name__)

@@ -32,6 +36,15 @@ async def provider_enabled(provider_name: ModelProviderName) -> bool:
         except Exception:
             return False

+    if provider_name == ModelProviderName.docker_model_runner:
+        try:
+            conn = await get_docker_model_runner_connection()
+            return conn is not None and (
+                len(conn.supported_models) > 0 or len(conn.untested_models) > 0
+            )
+        except Exception:
+            return False
+
     provider_warning = provider_warnings.get(provider_name)
     if provider_warning is None:
         return False
@@ -180,50 +193,6 @@ def kiln_model_provider_from(
     )


-def lite_llm_config_for_openai_compatible(
-    run_config_properties: RunConfigProperties,
-) -> LiteLlmConfig:
-    model_id = run_config_properties.model_name
-    try:
-        openai_provider_name, model_id = model_id.split("::")
-    except Exception:
-        raise ValueError(f"Invalid openai compatible model ID: {model_id}")
-
-    openai_compatible_providers = Config.shared().openai_compatible_providers or []
-    provider = next(
-        filter(
-            lambda p: p.get("name") == openai_provider_name, openai_compatible_providers
-        ),
-        None,
-    )
-    if provider is None:
-        raise ValueError(f"OpenAI compatible provider {openai_provider_name} not found")
-
-    # API key optional - some providers like Ollama don't use it, but LiteLLM errors without one
-    api_key = provider.get("api_key") or "NA"
-    base_url = provider.get("base_url")
-    if base_url is None:
-        raise ValueError(
-            f"OpenAI compatible provider {openai_provider_name} has no base URL"
-        )
-
-    # Update a copy of the run config properties to use the openai compatible provider
-    updated_run_config_properties = run_config_properties.model_copy(deep=True)
-    updated_run_config_properties.model_provider_name = (
-        ModelProviderName.openai_compatible
-    )
-    updated_run_config_properties.model_name = model_id
-
-    return LiteLlmConfig(
-        # OpenAI compatible, with a custom base URL
-        run_config_properties=updated_run_config_properties,
-        base_url=base_url,
-        additional_body_options={
-            "api_key": api_key,
-        },
-    )
-
-
 def lite_llm_provider_model(
     model_id: str,
 ) -> KilnModelProvider:
@@ -377,6 +346,8 @@ def provider_name_from_id(id: str) -> str:
             return "SiliconFlow"
         case ModelProviderName.cerebras:
             return "Cerebras"
+        case ModelProviderName.docker_model_runner:
+            return "Docker Model Runner"
         case _:
             # triggers pyright warning if I miss a case
             raise_exhaustive_enum_error(enum_id)
@@ -444,3 +415,190 @@ provider_warnings: Dict[ModelProviderName, ModelProviderWarning] = {
         message="Attempted to use Cerebras without an API key set. \nGet your API key from https://cloud.cerebras.ai/platform",
     ),
 }
+
+
+class LiteLlmCoreConfig(BaseModel):
+    base_url: str | None = None
+    default_headers: Dict[str, str] | None = None
+    additional_body_options: Dict[str, str] | None = None
+
+
+def lite_llm_core_config_for_provider(
+    provider_name: ModelProviderName,
+    openai_compatible_provider_name: str | None = None,
+) -> LiteLlmCoreConfig | None:
+    """
+    Returns a LiteLLM core config for a given provider.
+
+    Args:
+        provider_name: The provider to get the config for
+        openai_compatible_provider_name: Required for openai compatible providers, this is the name of the underlying provider
+    """
+    match provider_name:
+        case ModelProviderName.openrouter:
+            return LiteLlmCoreConfig(
+                base_url=(
+                    os.getenv("OPENROUTER_BASE_URL") or "https://openrouter.ai/api/v1"
+                ),
+                default_headers={
+                    "HTTP-Referer": "https://kiln.tech/openrouter",
+                    "X-Title": "KilnAI",
+                },
+                additional_body_options={
+                    "api_key": Config.shared().open_router_api_key,
+                },
+            )
+        case ModelProviderName.siliconflow_cn:
+            return LiteLlmCoreConfig(
+                base_url=os.getenv("SILICONFLOW_BASE_URL")
+                or "https://api.siliconflow.cn/v1",
+                default_headers={
+                    "HTTP-Referer": "https://kiln.tech/siliconflow",
+                    "X-Title": "KilnAI",
+                },
+                additional_body_options={
+                    "api_key": Config.shared().siliconflow_cn_api_key,
+                },
+            )
+        case ModelProviderName.openai:
+            return LiteLlmCoreConfig(
+                additional_body_options={
+                    "api_key": Config.shared().open_ai_api_key,
+                },
+            )
+        case ModelProviderName.groq:
+            return LiteLlmCoreConfig(
+                additional_body_options={
+                    "api_key": Config.shared().groq_api_key,
+                },
+            )
+        case ModelProviderName.amazon_bedrock:
+            return LiteLlmCoreConfig(
+                additional_body_options={
+                    "aws_access_key_id": Config.shared().bedrock_access_key,
+                    "aws_secret_access_key": Config.shared().bedrock_secret_key,
+                    # The only region that's widely supported for bedrock
+                    "aws_region_name": "us-west-2",
+                },
+            )
+        case ModelProviderName.ollama:
+            # Set the Ollama base URL for 2 reasons:
+            # 1. To use the correct base URL
+            # 2. We use Ollama's OpenAI compatible API (/v1), and don't just let litellm use the Ollama API. We use more advanced features like json_schema.
+            ollama_base_url = (
+                Config.shared().ollama_base_url or "http://localhost:11434"
+            )
+            return LiteLlmCoreConfig(
+                base_url=ollama_base_url + "/v1",
+                additional_body_options={
+                    # LiteLLM errors without an api_key, even though Ollama doesn't support one
+                    "api_key": "NA",
+                },
+            )
+        case ModelProviderName.docker_model_runner:
+            docker_base_url = (
+                Config.shared().docker_model_runner_base_url
+                or "http://localhost:12434/engines/llama.cpp"
+            )
+            return LiteLlmCoreConfig(
+                # Docker Model Runner uses OpenAI-compatible API at /v1 endpoint
+                base_url=docker_base_url + "/v1",
+                additional_body_options={
+                    # LiteLLM errors without an api_key, even though Docker Model Runner doesn't require one.
+                    "api_key": "DMR",
+                },
+            )
+        case ModelProviderName.fireworks_ai:
+            return LiteLlmCoreConfig(
+                additional_body_options={
+                    "api_key": Config.shared().fireworks_api_key,
+                },
+            )
+        case ModelProviderName.anthropic:
+            return LiteLlmCoreConfig(
+                additional_body_options={
+                    "api_key": Config.shared().anthropic_api_key,
+                },
+            )
+        case ModelProviderName.gemini_api:
+            return LiteLlmCoreConfig(
+                additional_body_options={
+                    "api_key": Config.shared().gemini_api_key,
+                },
+            )
+        case ModelProviderName.vertex:
+            return LiteLlmCoreConfig(
+                additional_body_options={
+                    "vertex_project": Config.shared().vertex_project_id,
+                    "vertex_location": Config.shared().vertex_location,
+                },
+            )
+        case ModelProviderName.together_ai:
+            return LiteLlmCoreConfig(
+                additional_body_options={
+                    "api_key": Config.shared().together_api_key,
+                },
+            )
+        case ModelProviderName.azure_openai:
+            return LiteLlmCoreConfig(
+                base_url=Config.shared().azure_openai_endpoint,
+                additional_body_options={
+                    "api_key": Config.shared().azure_openai_api_key,
+                    "api_version": "2025-02-01-preview",
+                },
+            )
+        case ModelProviderName.huggingface:
+            return LiteLlmCoreConfig(
+                additional_body_options={
+                    "api_key": Config.shared().huggingface_api_key,
+                },
+            )
+        case ModelProviderName.cerebras:
+            return LiteLlmCoreConfig(
+                additional_body_options={
+                    "api_key": Config.shared().cerebras_api_key,
+                },
+            )
+        case ModelProviderName.openai_compatible:
+            # openai compatible requires a model name in the format "provider::model_name"
+            if openai_compatible_provider_name is None:
+                raise ValueError("OpenAI compatible provider requires a provider name")
+
+            openai_compatible_providers = (
+                Config.shared().openai_compatible_providers or []
+            )
+
+            provider = next(
+                filter(
+                    lambda p: p.get("name") == openai_compatible_provider_name,
+                    openai_compatible_providers,
+                ),
+                None,
+            )
+
+            if provider is None:
+                raise ValueError(
+                    f"OpenAI compatible provider {openai_compatible_provider_name} not found"
+                )
+
+            # API key optional - some providers like Ollama don't use it, but LiteLLM errors without one
+            api_key = provider.get("api_key") or "NA"
+            base_url = provider.get("base_url")
+            if base_url is None:
+                raise ValueError(
+                    f"OpenAI compatible provider {openai_compatible_provider_name} has no base URL"
+                )
+
+            return LiteLlmCoreConfig(
+                base_url=base_url,
+                additional_body_options={
+                    "api_key": api_key,
+                },
+            )
+        # These are virtual providers that should have mapped to an actual provider upstream (using core_provider method)
+        case ModelProviderName.kiln_fine_tune:
+            return None
+        case ModelProviderName.kiln_custom_registry:
+            return None
+        case _:
+            raise_exhaustive_enum_error(provider_name)
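The new LiteLlmCoreConfig / lite_llm_core_config_for_provider pair (this hunk appears to belong to kiln_ai/adapters/provider_tools.py, per the file list) centralizes per-provider LiteLLM connection settings and absorbs the removed lite_llm_config_for_openai_compatible helper. A short usage sketch based only on the code above; the module path is assumed from the file list, the printed base URL assumes no override is configured, and the openai-compatible provider name is hypothetical:

from kiln_ai.adapters.ml_model_list import ModelProviderName
from kiln_ai.adapters.provider_tools import lite_llm_core_config_for_provider  # path assumed from the file list

# Ollama is routed through its OpenAI-compatible /v1 endpoint with a placeholder api_key.
ollama_cfg = lite_llm_core_config_for_provider(ModelProviderName.ollama)
assert ollama_cfg is not None
print(ollama_cfg.base_url)  # "http://localhost:11434/v1" unless ollama_base_url is configured

# Virtual providers return None; they should have been mapped to a concrete provider upstream.
assert lite_llm_core_config_for_provider(ModelProviderName.kiln_fine_tune) is None

# openai_compatible needs the underlying provider's name and looks it up in
# Config.shared().openai_compatible_providers, raising ValueError if it isn't configured.
try:
    lite_llm_core_config_for_provider(
        ModelProviderName.openai_compatible,
        openai_compatible_provider_name="my-provider",  # hypothetical entry
    )
except ValueError as err:
    print(err)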
@@ -0,0 +1,49 @@
+from collections import defaultdict
+from typing import DefaultDict
+
+from kiln_ai.datamodel.chunk import ChunkedDocument
+from kiln_ai.datamodel.embedding import ChunkEmbeddings
+from kiln_ai.datamodel.extraction import Document, Extraction
+
+
+def deduplicate_extractions(items: list[Extraction]) -> list[Extraction]:
+    grouped_items: DefaultDict[str, list[Extraction]] = defaultdict(list)
+    for item in items:
+        if item.extractor_config_id is None:
+            raise ValueError("Extractor config ID is required")
+        grouped_items[item.extractor_config_id].append(item)
+    return [min(group, key=lambda x: x.created_at) for group in grouped_items.values()]
+
+
+def deduplicate_chunked_documents(
+    items: list[ChunkedDocument],
+) -> list[ChunkedDocument]:
+    grouped_items: DefaultDict[str, list[ChunkedDocument]] = defaultdict(list)
+    for item in items:
+        if item.chunker_config_id is None:
+            raise ValueError("Chunker config ID is required")
+        grouped_items[item.chunker_config_id].append(item)
+    return [min(group, key=lambda x: x.created_at) for group in grouped_items.values()]
+
+
+def deduplicate_chunk_embeddings(items: list[ChunkEmbeddings]) -> list[ChunkEmbeddings]:
+    grouped_items: DefaultDict[str, list[ChunkEmbeddings]] = defaultdict(list)
+    for item in items:
+        if item.embedding_config_id is None:
+            raise ValueError("Embedding config ID is required")
+        grouped_items[item.embedding_config_id].append(item)
+    return [min(group, key=lambda x: x.created_at) for group in grouped_items.values()]
+
+
+def filter_documents_by_tags(
+    documents: list[Document], tags: list[str] | None
+) -> list[Document]:
+    if not tags:
+        return documents
+
+    filtered_documents = []
+    for document in documents:
+        if document.tags and any(tag in document.tags for tag in tags):
+            filtered_documents.append(document)
+
+    return filtered_documents
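This new 49-line file lines up with kiln_ai/adapters/rag/deduplication.py from the file list. Its deduplicate_* helpers group items by their config ID and keep the earliest-created item in each group, raising if the ID is missing. A small illustrative sketch of that keep-the-oldest behaviour, using a stand-in dataclass instead of the real Kiln models:

from collections import defaultdict
from dataclasses import dataclass
from datetime import datetime


@dataclass
class FakeExtraction:
    # Stand-in for kiln_ai.datamodel.extraction.Extraction, just for this example.
    extractor_config_id: str | None
    created_at: datetime


def deduplicate(items: list[FakeExtraction]) -> list[FakeExtraction]:
    # Same logic as deduplicate_extractions above: one survivor per config ID, the earliest one.
    grouped: dict[str, list[FakeExtraction]] = defaultdict(list)
    for item in items:
        if item.extractor_config_id is None:
            raise ValueError("Extractor config ID is required")
        grouped[item.extractor_config_id].append(item)
    return [min(group, key=lambda x: x.created_at) for group in grouped.values()]


older = FakeExtraction("cfg-1", datetime(2024, 1, 1))
newer = FakeExtraction("cfg-1", datetime(2024, 6, 1))
other = FakeExtraction("cfg-2", datetime(2024, 3, 1))
assert deduplicate([newer, older, other]) == [older, other]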