kiln-ai 0.20.1__py3-none-any.whl → 0.22.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of kiln-ai might be problematic.
- kiln_ai/adapters/__init__.py +6 -0
- kiln_ai/adapters/adapter_registry.py +43 -226
- kiln_ai/adapters/chunkers/__init__.py +13 -0
- kiln_ai/adapters/chunkers/base_chunker.py +42 -0
- kiln_ai/adapters/chunkers/chunker_registry.py +16 -0
- kiln_ai/adapters/chunkers/fixed_window_chunker.py +39 -0
- kiln_ai/adapters/chunkers/helpers.py +23 -0
- kiln_ai/adapters/chunkers/test_base_chunker.py +63 -0
- kiln_ai/adapters/chunkers/test_chunker_registry.py +28 -0
- kiln_ai/adapters/chunkers/test_fixed_window_chunker.py +346 -0
- kiln_ai/adapters/chunkers/test_helpers.py +75 -0
- kiln_ai/adapters/data_gen/test_data_gen_task.py +9 -3
- kiln_ai/adapters/embedding/__init__.py +0 -0
- kiln_ai/adapters/embedding/base_embedding_adapter.py +44 -0
- kiln_ai/adapters/embedding/embedding_registry.py +32 -0
- kiln_ai/adapters/embedding/litellm_embedding_adapter.py +199 -0
- kiln_ai/adapters/embedding/test_base_embedding_adapter.py +283 -0
- kiln_ai/adapters/embedding/test_embedding_registry.py +166 -0
- kiln_ai/adapters/embedding/test_litellm_embedding_adapter.py +1149 -0
- kiln_ai/adapters/eval/eval_runner.py +6 -2
- kiln_ai/adapters/eval/test_base_eval.py +1 -3
- kiln_ai/adapters/eval/test_g_eval.py +1 -1
- kiln_ai/adapters/extractors/__init__.py +18 -0
- kiln_ai/adapters/extractors/base_extractor.py +72 -0
- kiln_ai/adapters/extractors/encoding.py +20 -0
- kiln_ai/adapters/extractors/extractor_registry.py +44 -0
- kiln_ai/adapters/extractors/extractor_runner.py +112 -0
- kiln_ai/adapters/extractors/litellm_extractor.py +406 -0
- kiln_ai/adapters/extractors/test_base_extractor.py +244 -0
- kiln_ai/adapters/extractors/test_encoding.py +54 -0
- kiln_ai/adapters/extractors/test_extractor_registry.py +181 -0
- kiln_ai/adapters/extractors/test_extractor_runner.py +181 -0
- kiln_ai/adapters/extractors/test_litellm_extractor.py +1290 -0
- kiln_ai/adapters/fine_tune/test_dataset_formatter.py +2 -2
- kiln_ai/adapters/fine_tune/test_fireworks_finetune.py +2 -6
- kiln_ai/adapters/fine_tune/test_together_finetune.py +2 -6
- kiln_ai/adapters/ml_embedding_model_list.py +494 -0
- kiln_ai/adapters/ml_model_list.py +876 -18
- kiln_ai/adapters/model_adapters/litellm_adapter.py +40 -75
- kiln_ai/adapters/model_adapters/test_litellm_adapter.py +79 -1
- kiln_ai/adapters/model_adapters/test_litellm_adapter_tools.py +119 -5
- kiln_ai/adapters/model_adapters/test_saving_adapter_results.py +9 -3
- kiln_ai/adapters/model_adapters/test_structured_output.py +9 -10
- kiln_ai/adapters/ollama_tools.py +69 -12
- kiln_ai/adapters/provider_tools.py +190 -46
- kiln_ai/adapters/rag/deduplication.py +49 -0
- kiln_ai/adapters/rag/progress.py +252 -0
- kiln_ai/adapters/rag/rag_runners.py +844 -0
- kiln_ai/adapters/rag/test_deduplication.py +195 -0
- kiln_ai/adapters/rag/test_progress.py +785 -0
- kiln_ai/adapters/rag/test_rag_runners.py +2376 -0
- kiln_ai/adapters/remote_config.py +80 -8
- kiln_ai/adapters/test_adapter_registry.py +579 -86
- kiln_ai/adapters/test_ml_embedding_model_list.py +239 -0
- kiln_ai/adapters/test_ml_model_list.py +202 -0
- kiln_ai/adapters/test_ollama_tools.py +340 -1
- kiln_ai/adapters/test_prompt_builders.py +1 -1
- kiln_ai/adapters/test_provider_tools.py +199 -8
- kiln_ai/adapters/test_remote_config.py +551 -56
- kiln_ai/adapters/vector_store/__init__.py +1 -0
- kiln_ai/adapters/vector_store/base_vector_store_adapter.py +83 -0
- kiln_ai/adapters/vector_store/lancedb_adapter.py +389 -0
- kiln_ai/adapters/vector_store/test_base_vector_store.py +160 -0
- kiln_ai/adapters/vector_store/test_lancedb_adapter.py +1841 -0
- kiln_ai/adapters/vector_store/test_vector_store_registry.py +199 -0
- kiln_ai/adapters/vector_store/vector_store_registry.py +33 -0
- kiln_ai/datamodel/__init__.py +16 -13
- kiln_ai/datamodel/basemodel.py +201 -4
- kiln_ai/datamodel/chunk.py +158 -0
- kiln_ai/datamodel/datamodel_enums.py +27 -0
- kiln_ai/datamodel/embedding.py +64 -0
- kiln_ai/datamodel/external_tool_server.py +206 -54
- kiln_ai/datamodel/extraction.py +317 -0
- kiln_ai/datamodel/project.py +33 -1
- kiln_ai/datamodel/rag.py +79 -0
- kiln_ai/datamodel/task.py +5 -0
- kiln_ai/datamodel/task_output.py +41 -11
- kiln_ai/datamodel/test_attachment.py +649 -0
- kiln_ai/datamodel/test_basemodel.py +270 -14
- kiln_ai/datamodel/test_chunk_models.py +317 -0
- kiln_ai/datamodel/test_dataset_split.py +1 -1
- kiln_ai/datamodel/test_datasource.py +50 -0
- kiln_ai/datamodel/test_embedding_models.py +448 -0
- kiln_ai/datamodel/test_eval_model.py +6 -6
- kiln_ai/datamodel/test_external_tool_server.py +534 -152
- kiln_ai/datamodel/test_extraction_chunk.py +206 -0
- kiln_ai/datamodel/test_extraction_model.py +501 -0
- kiln_ai/datamodel/test_rag.py +641 -0
- kiln_ai/datamodel/test_task.py +35 -1
- kiln_ai/datamodel/test_tool_id.py +187 -1
- kiln_ai/datamodel/test_vector_store.py +320 -0
- kiln_ai/datamodel/tool_id.py +58 -0
- kiln_ai/datamodel/vector_store.py +141 -0
- kiln_ai/tools/base_tool.py +12 -3
- kiln_ai/tools/built_in_tools/math_tools.py +12 -4
- kiln_ai/tools/kiln_task_tool.py +158 -0
- kiln_ai/tools/mcp_server_tool.py +2 -2
- kiln_ai/tools/mcp_session_manager.py +51 -22
- kiln_ai/tools/rag_tools.py +164 -0
- kiln_ai/tools/test_kiln_task_tool.py +527 -0
- kiln_ai/tools/test_mcp_server_tool.py +4 -15
- kiln_ai/tools/test_mcp_session_manager.py +187 -227
- kiln_ai/tools/test_rag_tools.py +929 -0
- kiln_ai/tools/test_tool_registry.py +290 -7
- kiln_ai/tools/tool_registry.py +69 -16
- kiln_ai/utils/__init__.py +3 -0
- kiln_ai/utils/async_job_runner.py +62 -17
- kiln_ai/utils/config.py +2 -2
- kiln_ai/utils/env.py +15 -0
- kiln_ai/utils/filesystem.py +14 -0
- kiln_ai/utils/filesystem_cache.py +60 -0
- kiln_ai/utils/litellm.py +94 -0
- kiln_ai/utils/lock.py +100 -0
- kiln_ai/utils/mime_type.py +38 -0
- kiln_ai/utils/open_ai_types.py +19 -2
- kiln_ai/utils/pdf_utils.py +59 -0
- kiln_ai/utils/test_async_job_runner.py +151 -35
- kiln_ai/utils/test_env.py +142 -0
- kiln_ai/utils/test_filesystem_cache.py +316 -0
- kiln_ai/utils/test_litellm.py +206 -0
- kiln_ai/utils/test_lock.py +185 -0
- kiln_ai/utils/test_mime_type.py +66 -0
- kiln_ai/utils/test_open_ai_types.py +88 -12
- kiln_ai/utils/test_pdf_utils.py +86 -0
- kiln_ai/utils/test_uuid.py +111 -0
- kiln_ai/utils/test_validation.py +524 -0
- kiln_ai/utils/uuid.py +9 -0
- kiln_ai/utils/validation.py +90 -0
- {kiln_ai-0.20.1.dist-info → kiln_ai-0.22.0.dist-info}/METADATA +9 -1
- kiln_ai-0.22.0.dist-info/RECORD +213 -0
- kiln_ai-0.20.1.dist-info/RECORD +0 -138
- {kiln_ai-0.20.1.dist-info → kiln_ai-0.22.0.dist-info}/WHEEL +0 -0
- {kiln_ai-0.20.1.dist-info → kiln_ai-0.22.0.dist-info}/licenses/LICENSE.txt +0 -0
kiln_ai/adapters/ollama_tools.py
CHANGED
@@ -4,6 +4,7 @@ import httpx
 import requests
 from pydantic import BaseModel, Field
 
+from kiln_ai.adapters.ml_embedding_model_list import built_in_embedding_models
 from kiln_ai.adapters.ml_model_list import ModelProviderName, built_in_models
 from kiln_ai.utils.config import Config
 
@@ -41,22 +42,28 @@ class OllamaConnection(BaseModel):
     version: str | None = None
     supported_models: List[str]
     untested_models: List[str] = Field(default_factory=list)
+    supported_embedding_models: List[str] = Field(default_factory=list)
 
     def all_models(self) -> List[str]:
         return self.supported_models + self.untested_models
 
+    def all_embedding_models(self) -> List[str]:
+        return self.supported_embedding_models
+
 
 # Parse the Ollama /api/tags response
-def parse_ollama_tags(tags: Any) -> OllamaConnection | None:
+def parse_ollama_tags(tags: Any) -> OllamaConnection:
     # Build a list of models we support for Ollama from the built-in model list
-    supported_ollama_models = [
-        provider.model_id
-        for model in built_in_models
-        for provider in model.providers
-        if provider.name == ModelProviderName.ollama
-    ]
+    supported_ollama_models = set(
+        [
+            provider.model_id
+            for model in built_in_models
+            for provider in model.providers
+            if provider.name == ModelProviderName.ollama
+        ]
+    )
     # Append model_aliases to supported_ollama_models
-    supported_ollama_models.extend(
+    supported_ollama_models.update(
         [
            alias
            for model in built_in_models
@@ -65,16 +72,44 @@ def parse_ollama_tags(tags: Any) -> OllamaConnection | None:
         ]
     )
 
+    supported_ollama_embedding_models = set(
+        [
+            provider.model_id
+            for model in built_in_embedding_models
+            for provider in model.providers
+            if provider.name == ModelProviderName.ollama
+        ]
+    )
+    supported_ollama_embedding_models.update(
+        [
+            alias
+            for model in built_in_embedding_models
+            for provider in model.providers
+            for alias in provider.ollama_model_aliases or []
+        ]
+    )
+
     if "models" in tags:
         models = tags["models"]
         if isinstance(models, list):
             model_names = [model["model"] for model in models]
             available_supported_models = []
             untested_models = []
-            supported_models_latest_aliases = [
-                f"{m}:latest" for m in supported_ollama_models
-            ]
+            supported_models_latest_aliases = set(
+                [f"{m}:latest" for m in supported_ollama_models]
+            )
+            supported_embedding_models_latest_aliases = set(
+                [f"{m}:latest" for m in supported_ollama_embedding_models]
+            )
+
             for model in model_names:
+                # Skip embedding models - they should only appear in supported_embedding_models
+                if (
+                    model in supported_ollama_embedding_models
+                    or model in supported_embedding_models_latest_aliases
+                ):
+                    continue
+
                 if (
                     model in supported_ollama_models
                     or model in supported_models_latest_aliases
@@ -83,17 +118,31 @@ def parse_ollama_tags(tags: Any) -> OllamaConnection | None:
                 else:
                     untested_models.append(model)
 
-            if available_supported_models or untested_models:
+            available_supported_embedding_models = []
+            for model in model_names:
+                if (
+                    model in supported_ollama_embedding_models
+                    or model in supported_embedding_models_latest_aliases
+                ):
+                    available_supported_embedding_models.append(model)
+
+            if (
+                available_supported_models
+                or untested_models
+                or available_supported_embedding_models
+            ):
                 return OllamaConnection(
                     message="Ollama connected",
                     supported_models=available_supported_models,
                     untested_models=untested_models,
+                    supported_embedding_models=available_supported_embedding_models,
                 )
 
     return OllamaConnection(
         message="Ollama is running, but no supported models are installed. Install one or more supported model, like 'ollama pull phi3.5'.",
         supported_models=[],
         untested_models=[],
+        supported_embedding_models=[],
     )
 
 
@@ -113,3 +162,11 @@ async def get_ollama_connection() -> OllamaConnection | None:
 def ollama_model_installed(conn: OllamaConnection, model_name: str) -> bool:
     all_models = conn.all_models()
     return model_name in all_models or f"{model_name}:latest" in all_models
+
+
+def ollama_embedding_model_installed(conn: OllamaConnection, model_name: str) -> bool:
+    all_embedding_models = conn.all_embedding_models()
+    return (
+        model_name in all_embedding_models
+        or f"{model_name}:latest" in all_embedding_models
+    )
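The new parsing path splits installed Ollama models into separate chat and embedding buckets. A minimal sketch of exercising it, assuming kiln-ai 0.22.0 is installed; the payload mimics Ollama's /api/tags response and the model names are illustrative (which bucket a name lands in depends on Kiln's built-in model lists):

# A minimal sketch, assuming kiln-ai 0.22.0 is installed. The payload mimics
# Ollama's /api/tags response; model names are illustrative.
from kiln_ai.adapters.ollama_tools import parse_ollama_tags

tags = {
    "models": [
        {"model": "phi3.5:latest"},            # recognized chat model
        {"model": "nomic-embed-text:latest"},  # recognized embedding model (if in the built-in list)
        {"model": "my-local-finetune:7b"},     # unknown -> untested_models
    ]
}

conn = parse_ollama_tags(tags)
print(conn.supported_models)            # chat models Kiln supports
print(conn.supported_embedding_models)  # embedding models, kept out of the chat lists
print(conn.untested_models)             # installed but not in Kiln's built-in lists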
kiln_ai/adapters/provider_tools.py
CHANGED
@@ -1,7 +1,10 @@
 import logging
+import os
 from dataclasses import dataclass
 from typing import Dict, List
 
+from pydantic import BaseModel
+
 from kiln_ai.adapters.docker_model_runner_tools import (
     get_docker_model_runner_connection,
 )
@@ -13,11 +16,9 @@ from kiln_ai.adapters.ml_model_list import (
     StructuredOutputMode,
     built_in_models,
 )
-from kiln_ai.adapters.model_adapters.litellm_config import LiteLlmConfig
 from kiln_ai.adapters.ollama_tools import get_ollama_connection
 from kiln_ai.datamodel import Finetune, Task
 from kiln_ai.datamodel.datamodel_enums import ChatStrategy
-from kiln_ai.datamodel.task import RunConfigProperties
 from kiln_ai.utils.config import Config
 from kiln_ai.utils.exhaustive_error import raise_exhaustive_enum_error
 from kiln_ai.utils.project_utils import project_from_id
@@ -192,50 +193,6 @@ def kiln_model_provider_from(
     )
 
 
-def lite_llm_config_for_openai_compatible(
-    run_config_properties: RunConfigProperties,
-) -> LiteLlmConfig:
-    model_id = run_config_properties.model_name
-    try:
-        openai_provider_name, model_id = model_id.split("::")
-    except Exception:
-        raise ValueError(f"Invalid openai compatible model ID: {model_id}")
-
-    openai_compatible_providers = Config.shared().openai_compatible_providers or []
-    provider = next(
-        filter(
-            lambda p: p.get("name") == openai_provider_name, openai_compatible_providers
-        ),
-        None,
-    )
-    if provider is None:
-        raise ValueError(f"OpenAI compatible provider {openai_provider_name} not found")
-
-    # API key optional - some providers like Ollama don't use it, but LiteLLM errors without one
-    api_key = provider.get("api_key") or "NA"
-    base_url = provider.get("base_url")
-    if base_url is None:
-        raise ValueError(
-            f"OpenAI compatible provider {openai_provider_name} has no base URL"
-        )
-
-    # Update a copy of the run config properties to use the openai compatible provider
-    updated_run_config_properties = run_config_properties.model_copy(deep=True)
-    updated_run_config_properties.model_provider_name = (
-        ModelProviderName.openai_compatible
-    )
-    updated_run_config_properties.model_name = model_id
-
-    return LiteLlmConfig(
-        # OpenAI compatible, with a custom base URL
-        run_config_properties=updated_run_config_properties,
-        base_url=base_url,
-        additional_body_options={
-            "api_key": api_key,
-        },
-    )
-
-
 def lite_llm_provider_model(
     model_id: str,
 ) -> KilnModelProvider:
@@ -458,3 +415,190 @@ provider_warnings: Dict[ModelProviderName, ModelProviderWarning] = {
         message="Attempted to use Cerebras without an API key set. \nGet your API key from https://cloud.cerebras.ai/platform",
     ),
 }
+
+
+class LiteLlmCoreConfig(BaseModel):
+    base_url: str | None = None
+    default_headers: Dict[str, str] | None = None
+    additional_body_options: Dict[str, str] | None = None
+
+
+def lite_llm_core_config_for_provider(
+    provider_name: ModelProviderName,
+    openai_compatible_provider_name: str | None = None,
+) -> LiteLlmCoreConfig | None:
+    """
+    Returns a LiteLLM core config for a given provider.
+
+    Args:
+        provider_name: The provider to get the config for
+        openai_compatible_provider_name: Required for openai compatible providers, this is the name of the underlying provider
+    """
+    match provider_name:
+        case ModelProviderName.openrouter:
+            return LiteLlmCoreConfig(
+                base_url=(
+                    os.getenv("OPENROUTER_BASE_URL") or "https://openrouter.ai/api/v1"
+                ),
+                default_headers={
+                    "HTTP-Referer": "https://kiln.tech/openrouter",
+                    "X-Title": "KilnAI",
+                },
+                additional_body_options={
+                    "api_key": Config.shared().open_router_api_key,
+                },
+            )
+        case ModelProviderName.siliconflow_cn:
+            return LiteLlmCoreConfig(
+                base_url=os.getenv("SILICONFLOW_BASE_URL")
+                or "https://api.siliconflow.cn/v1",
+                default_headers={
+                    "HTTP-Referer": "https://kiln.tech/siliconflow",
+                    "X-Title": "KilnAI",
+                },
+                additional_body_options={
+                    "api_key": Config.shared().siliconflow_cn_api_key,
+                },
+            )
+        case ModelProviderName.openai:
+            return LiteLlmCoreConfig(
+                additional_body_options={
+                    "api_key": Config.shared().open_ai_api_key,
+                },
+            )
+        case ModelProviderName.groq:
+            return LiteLlmCoreConfig(
+                additional_body_options={
+                    "api_key": Config.shared().groq_api_key,
+                },
+            )
+        case ModelProviderName.amazon_bedrock:
+            return LiteLlmCoreConfig(
+                additional_body_options={
+                    "aws_access_key_id": Config.shared().bedrock_access_key,
+                    "aws_secret_access_key": Config.shared().bedrock_secret_key,
+                    # The only region that's widely supported for bedrock
+                    "aws_region_name": "us-west-2",
+                },
+            )
+        case ModelProviderName.ollama:
+            # Set the Ollama base URL for 2 reasons:
+            # 1. To use the correct base URL
+            # 2. We use Ollama's OpenAI compatible API (/v1), and don't just let litellm use the Ollama API. We use more advanced features like json_schema.
+            ollama_base_url = (
+                Config.shared().ollama_base_url or "http://localhost:11434"
+            )
+            return LiteLlmCoreConfig(
+                base_url=ollama_base_url + "/v1",
+                additional_body_options={
+                    # LiteLLM errors without an api_key, even though Ollama doesn't support one
+                    "api_key": "NA",
+                },
+            )
+        case ModelProviderName.docker_model_runner:
+            docker_base_url = (
+                Config.shared().docker_model_runner_base_url
+                or "http://localhost:12434/engines/llama.cpp"
+            )
+            return LiteLlmCoreConfig(
+                # Docker Model Runner uses OpenAI-compatible API at /v1 endpoint
+                base_url=docker_base_url + "/v1",
+                additional_body_options={
+                    # LiteLLM errors without an api_key, even though Docker Model Runner doesn't require one.
+                    "api_key": "DMR",
+                },
+            )
+        case ModelProviderName.fireworks_ai:
+            return LiteLlmCoreConfig(
+                additional_body_options={
+                    "api_key": Config.shared().fireworks_api_key,
+                },
+            )
+        case ModelProviderName.anthropic:
+            return LiteLlmCoreConfig(
+                additional_body_options={
+                    "api_key": Config.shared().anthropic_api_key,
+                },
+            )
+        case ModelProviderName.gemini_api:
+            return LiteLlmCoreConfig(
+                additional_body_options={
+                    "api_key": Config.shared().gemini_api_key,
+                },
+            )
+        case ModelProviderName.vertex:
+            return LiteLlmCoreConfig(
+                additional_body_options={
+                    "vertex_project": Config.shared().vertex_project_id,
+                    "vertex_location": Config.shared().vertex_location,
+                },
+            )
+        case ModelProviderName.together_ai:
+            return LiteLlmCoreConfig(
+                additional_body_options={
+                    "api_key": Config.shared().together_api_key,
+                },
+            )
+        case ModelProviderName.azure_openai:
+            return LiteLlmCoreConfig(
+                base_url=Config.shared().azure_openai_endpoint,
+                additional_body_options={
+                    "api_key": Config.shared().azure_openai_api_key,
+                    "api_version": "2025-02-01-preview",
+                },
+            )
+        case ModelProviderName.huggingface:
+            return LiteLlmCoreConfig(
+                additional_body_options={
+                    "api_key": Config.shared().huggingface_api_key,
+                },
+            )
+        case ModelProviderName.cerebras:
+            return LiteLlmCoreConfig(
+                additional_body_options={
+                    "api_key": Config.shared().cerebras_api_key,
+                },
+            )
+        case ModelProviderName.openai_compatible:
+            # openai compatible requires a model name in the format "provider::model_name"
+            if openai_compatible_provider_name is None:
+                raise ValueError("OpenAI compatible provider requires a provider name")
+
+            openai_compatible_providers = (
+                Config.shared().openai_compatible_providers or []
+            )
+
+            provider = next(
+                filter(
+                    lambda p: p.get("name") == openai_compatible_provider_name,
+                    openai_compatible_providers,
+                ),
+                None,
+            )
+
+            if provider is None:
+                raise ValueError(
+                    f"OpenAI compatible provider {openai_compatible_provider_name} not found"
+                )
+
+            # API key optional - some providers like Ollama don't use it, but LiteLLM errors without one
+            api_key = provider.get("api_key") or "NA"
+            base_url = provider.get("base_url")
+            if base_url is None:
+                raise ValueError(
+                    f"OpenAI compatible provider {openai_compatible_provider_name} has no base URL"
+                )
+
+            return LiteLlmCoreConfig(
+                base_url=base_url,
+                additional_body_options={
+                    "api_key": api_key,
+                },
+            )
+        # These are virtual providers that should have mapped to an actual provider upstream (using core_provider method)
+        case ModelProviderName.kiln_fine_tune:
+            return None
+        case ModelProviderName.kiln_custom_registry:
+            return None
+        case _:
+            raise_exhaustive_enum_error(provider_name)
kiln_ai/adapters/rag/deduplication.py
ADDED
@@ -0,0 +1,49 @@
+from collections import defaultdict
+from typing import DefaultDict
+
+from kiln_ai.datamodel.chunk import ChunkedDocument
+from kiln_ai.datamodel.embedding import ChunkEmbeddings
+from kiln_ai.datamodel.extraction import Document, Extraction
+
+
+def deduplicate_extractions(items: list[Extraction]) -> list[Extraction]:
+    grouped_items: DefaultDict[str, list[Extraction]] = defaultdict(list)
+    for item in items:
+        if item.extractor_config_id is None:
+            raise ValueError("Extractor config ID is required")
+        grouped_items[item.extractor_config_id].append(item)
+    return [min(group, key=lambda x: x.created_at) for group in grouped_items.values()]
+
+
+def deduplicate_chunked_documents(
+    items: list[ChunkedDocument],
+) -> list[ChunkedDocument]:
+    grouped_items: DefaultDict[str, list[ChunkedDocument]] = defaultdict(list)
+    for item in items:
+        if item.chunker_config_id is None:
+            raise ValueError("Chunker config ID is required")
+        grouped_items[item.chunker_config_id].append(item)
+    return [min(group, key=lambda x: x.created_at) for group in grouped_items.values()]
+
+
+def deduplicate_chunk_embeddings(items: list[ChunkEmbeddings]) -> list[ChunkEmbeddings]:
+    grouped_items: DefaultDict[str, list[ChunkEmbeddings]] = defaultdict(list)
+    for item in items:
+        if item.embedding_config_id is None:
+            raise ValueError("Embedding config ID is required")
+        grouped_items[item.embedding_config_id].append(item)
+    return [min(group, key=lambda x: x.created_at) for group in grouped_items.values()]
+
+
+def filter_documents_by_tags(
+    documents: list[Document], tags: list[str] | None
+) -> list[Document]:
+    if not tags:
+        return documents
+
+    filtered_documents = []
+    for document in documents:
+        if document.tags and any(tag in document.tags for tag in tags):
+            filtered_documents.append(document)
+
+    return filtered_documents