kiln-ai 0.12.0__py3-none-any.whl → 0.13.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of kiln-ai might be problematic.
- kiln_ai/adapters/__init__.py +4 -0
- kiln_ai/adapters/adapter_registry.py +153 -28
- kiln_ai/adapters/eval/__init__.py +28 -0
- kiln_ai/adapters/eval/eval_runner.py +4 -1
- kiln_ai/adapters/eval/g_eval.py +2 -1
- kiln_ai/adapters/eval/test_base_eval.py +1 -0
- kiln_ai/adapters/eval/test_eval_runner.py +1 -0
- kiln_ai/adapters/eval/test_g_eval.py +1 -0
- kiln_ai/adapters/fine_tune/base_finetune.py +16 -2
- kiln_ai/adapters/fine_tune/finetune_registry.py +2 -0
- kiln_ai/adapters/fine_tune/test_together_finetune.py +531 -0
- kiln_ai/adapters/fine_tune/together_finetune.py +325 -0
- kiln_ai/adapters/ml_model_list.py +638 -155
- kiln_ai/adapters/model_adapters/__init__.py +2 -4
- kiln_ai/adapters/model_adapters/base_adapter.py +14 -11
- kiln_ai/adapters/model_adapters/litellm_adapter.py +391 -0
- kiln_ai/adapters/model_adapters/litellm_config.py +13 -0
- kiln_ai/adapters/model_adapters/test_litellm_adapter.py +407 -0
- kiln_ai/adapters/model_adapters/test_structured_output.py +23 -5
- kiln_ai/adapters/ollama_tools.py +3 -2
- kiln_ai/adapters/parsers/r1_parser.py +19 -14
- kiln_ai/adapters/parsers/test_r1_parser.py +17 -5
- kiln_ai/adapters/provider_tools.py +50 -58
- kiln_ai/adapters/repair/test_repair_task.py +3 -3
- kiln_ai/adapters/run_output.py +1 -1
- kiln_ai/adapters/test_adapter_registry.py +17 -20
- kiln_ai/adapters/test_generate_docs.py +2 -2
- kiln_ai/adapters/test_prompt_adaptors.py +30 -19
- kiln_ai/adapters/test_provider_tools.py +26 -81
- kiln_ai/datamodel/basemodel.py +2 -0
- kiln_ai/datamodel/datamodel_enums.py +2 -0
- kiln_ai/datamodel/json_schema.py +1 -1
- kiln_ai/datamodel/task_output.py +13 -6
- kiln_ai/datamodel/test_basemodel.py +9 -0
- kiln_ai/datamodel/test_datasource.py +19 -0
- kiln_ai/utils/config.py +37 -0
- kiln_ai/utils/dataset_import.py +232 -0
- kiln_ai/utils/test_dataset_import.py +596 -0
- {kiln_ai-0.12.0.dist-info → kiln_ai-0.13.0.dist-info}/METADATA +51 -7
- {kiln_ai-0.12.0.dist-info → kiln_ai-0.13.0.dist-info}/RECORD +42 -39
- kiln_ai/adapters/model_adapters/langchain_adapters.py +0 -309
- kiln_ai/adapters/model_adapters/openai_compatible_config.py +0 -10
- kiln_ai/adapters/model_adapters/openai_model_adapter.py +0 -289
- kiln_ai/adapters/model_adapters/test_langchain_adapter.py +0 -343
- kiln_ai/adapters/model_adapters/test_openai_model_adapter.py +0 -216
- {kiln_ai-0.12.0.dist-info → kiln_ai-0.13.0.dist-info}/WHEEL +0 -0
- {kiln_ai-0.12.0.dist-info → kiln_ai-0.13.0.dist-info}/licenses/LICENSE.txt +0 -0
@@ -9,8 +9,8 @@ from kiln_ai.adapters.ml_model_list import (
     StructuredOutputMode,
     built_in_models,
 )
-from kiln_ai.adapters.model_adapters.…
-…
+from kiln_ai.adapters.model_adapters.litellm_config import (
+    LiteLlmConfig,
 )
 from kiln_ai.adapters.ollama_tools import (
     get_ollama_connection,
@@ -153,7 +153,7 @@ def kiln_model_provider_from(
         return finetune_provider_model(name)

     if provider_name == ModelProviderName.openai_compatible:
-        return …
+        return lite_llm_provider_model(name)

     built_in_model = builtin_model_from(name, provider_name)
     if built_in_model:
@@ -175,13 +175,13 @@ def kiln_model_provider_from(
         supports_structured_output=False,
         supports_data_gen=False,
         untested_model=True,
-        …
+        model_id=name,
     )


-def …
+def lite_llm_config(
     model_id: str,
-) -> …
+) -> LiteLlmConfig:
     try:
         openai_provider_name, model_id = model_id.split("::")
     except Exception:
@@ -205,22 +205,23 @@ def openai_compatible_config(
             f"OpenAI compatible provider {openai_provider_name} has no base URL"
         )

-    return …
-    …
+    return LiteLlmConfig(
+        # OpenAI compatible, with a custom base URL
         model_name=model_id,
         provider_name=ModelProviderName.openai_compatible,
         base_url=base_url,
+        additional_body_options={
+            "api_key": api_key,
+        },
     )


-def …
+def lite_llm_provider_model(
     model_id: str,
 ) -> KilnModelProvider:
     return KilnModelProvider(
         name=ModelProviderName.openai_compatible,
-        …
-            "model": model_id,
-        },
+        model_id=model_id,
         supports_structured_output=False,
         supports_data_gen=False,
         untested_model=True,
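The new lite_llm_config path replaces the old OpenAI-compatible config builder: a custom model ID of the form "provider::model" is split on "::", the provider's saved base URL is looked up, and the API key now travels in additional_body_options rather than a dedicated field. Below is a minimal standalone sketch of that flow, using a local dataclass and a hypothetical provider table as stand-ins for Kiln's LiteLlmConfig and saved settings.

from dataclasses import dataclass, field


@dataclass
class LiteLlmConfigSketch:
    # stand-in for kiln_ai's LiteLlmConfig; field names mirror the diff above
    model_name: str
    provider_name: str
    base_url: str
    additional_body_options: dict = field(default_factory=dict)


# hypothetical provider table; the real values live in Kiln's saved settings
OPENAI_COMPATIBLE_PROVIDERS = {
    "test_provider": {"base_url": "https://api.test.com", "api_key": "test-key"},
}


def lite_llm_config_sketch(model_id: str) -> LiteLlmConfigSketch:
    # "provider::model" is split into the provider entry and the bare model name
    try:
        provider_name, model_name = model_id.split("::")
    except Exception:
        raise ValueError(f"Invalid openai compatible model ID: {model_id}")

    provider = OPENAI_COMPATIBLE_PROVIDERS.get(provider_name)
    if provider is None:
        raise ValueError(f"OpenAI compatible provider {provider_name} not found")

    return LiteLlmConfigSketch(
        model_name=model_name,
        provider_name="openai_compatible",
        base_url=provider["base_url"],
        additional_body_options={"api_key": provider.get("api_key")},
    )


print(lite_llm_config_sketch("test_provider::gpt-4"))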
@@ -264,9 +265,7 @@ def finetune_provider_model(
     provider = ModelProviderName[fine_tune.provider]
     model_provider = KilnModelProvider(
         name=provider,
-        …
-            "model": fine_tune.fine_tune_model_id,
-        },
+        model_id=fine_tune.fine_tune_model_id,
     )

     if fine_tune.structured_output_mode is not None:
@@ -331,6 +330,18 @@ def provider_name_from_id(id: str) -> str:
                 return "Custom Models"
             case ModelProviderName.openai_compatible:
                 return "OpenAI Compatible"
+            case ModelProviderName.azure_openai:
+                return "Azure OpenAI"
+            case ModelProviderName.gemini_api:
+                return "Gemini API"
+            case ModelProviderName.anthropic:
+                return "Anthropic"
+            case ModelProviderName.huggingface:
+                return "Hugging Face"
+            case ModelProviderName.vertex:
+                return "Google Vertex AI"
+            case ModelProviderName.together_ai:
+                return "Together AI"
             case _:
                 # triggers pyright warning if I miss a case
                 raise_exhaustive_enum_error(enum_id)
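The "triggers pyright warning if I miss a case" comment refers to the exhaustive-match pattern these new provider cases extend: every enum member gets an explicit branch, and the wildcard branch calls a helper the type checker flags if a member is unhandled. A rough standalone sketch of the same pattern, using typing.assert_never as a stand-in for Kiln's raise_exhaustive_enum_error helper:

from enum import Enum
from typing import assert_never  # Python 3.11+; use typing_extensions on older versions


class Provider(str, Enum):
    # small illustrative subset of provider names
    anthropic = "anthropic"
    gemini_api = "gemini_api"
    together_ai = "together_ai"


def display_name(provider: Provider) -> str:
    match provider:
        case Provider.anthropic:
            return "Anthropic"
        case Provider.gemini_api:
            return "Gemini API"
        case Provider.together_ai:
            return "Together AI"
        case _:
            # pyright reports the missed member here if a new one is added above
            assert_never(provider)


print(display_name(Provider.together_ai))  # "Together AI"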
@@ -338,49 +349,6 @@ def provider_name_from_id(id: str) -> str:
     return "Unknown provider: " + id


-def provider_options_for_custom_model(
-    model_name: str, provider_name: str
-) -> Dict[str, str]:
-    """
-    Generated model provider options for a custom model. Each has their own format/options.
-    """
-
-    if provider_name not in ModelProviderName.__members__:
-        raise ValueError(f"Invalid provider name: {provider_name}")
-
-    enum_id = ModelProviderName(provider_name)
-    match enum_id:
-        case ModelProviderName.amazon_bedrock:
-            # us-west-2 is the only region consistently supported by Bedrock
-            return {"model": model_name, "region_name": "us-west-2"}
-        case (
-            ModelProviderName.openai
-            | ModelProviderName.ollama
-            | ModelProviderName.fireworks_ai
-            | ModelProviderName.openrouter
-            | ModelProviderName.groq
-        ):
-            return {"model": model_name}
-        case ModelProviderName.kiln_custom_registry:
-            raise ValueError(
-                "Custom models from registry should be parsed into provider/model before calling this."
-            )
-        case ModelProviderName.kiln_fine_tune:
-            raise ValueError(
-                "Fine tuned models should populate provider options via another path"
-            )
-        case ModelProviderName.openai_compatible:
-            raise ValueError(
-                "OpenAI compatible models should populate provider options via another path"
-            )
-        case _:
-            # triggers pyright warning if I miss a case
-            raise_exhaustive_enum_error(enum_id)
-
-    # Won't reach this, type checking will catch missed values
-    return {"model": model_name}
-
-
 @dataclass
 class ModelProviderWarning:
     required_config_keys: List[str]
@@ -408,4 +376,28 @@ provider_warnings: Dict[ModelProviderName, ModelProviderWarning] = {
         required_config_keys=["fireworks_api_key", "fireworks_account_id"],
         message="Attempted to use Fireworks without an API key and account ID set. \nGet your API key from https://fireworks.ai/account/api-keys and your account ID from https://fireworks.ai/account/profile",
     ),
+    ModelProviderName.anthropic: ModelProviderWarning(
+        required_config_keys=["anthropic_api_key"],
+        message="Attempted to use Anthropic without an API key set. \nGet your API key from https://console.anthropic.com/settings/keys",
+    ),
+    ModelProviderName.gemini_api: ModelProviderWarning(
+        required_config_keys=["gemini_api_key"],
+        message="Attempted to use Gemini without an API key set. \nGet your API key from https://aistudio.google.com/app/apikey",
+    ),
+    ModelProviderName.azure_openai: ModelProviderWarning(
+        required_config_keys=["azure_openai_api_key", "azure_openai_endpoint"],
+        message="Attempted to use Azure OpenAI without an API key and endpoint set. Configure these in settings.",
+    ),
+    ModelProviderName.huggingface: ModelProviderWarning(
+        required_config_keys=["huggingface_api_key"],
+        message="Attempted to use Hugging Face without an API key set. \nGet your API key from https://huggingface.co/settings/tokens",
+    ),
+    ModelProviderName.vertex: ModelProviderWarning(
+        required_config_keys=["vertex_project_id"],
+        message="Attempted to use Vertex without a project ID set. \nGet your project ID from the Vertex AI console.",
+    ),
+    ModelProviderName.together_ai: ModelProviderWarning(
+        required_config_keys=["together_api_key"],
+        message="Attempted to use Together without an API key set. \nGet your API key from https://together.ai/settings/keys",
+    ),
 }
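Each ModelProviderWarning added above pairs the config keys a provider requires with the message shown when they are missing. A hedged sketch of how such a table can be consulted before a request is attempted; the settings dict and check_provider_config helper here are illustrative, not Kiln's real config API:

from dataclasses import dataclass
from typing import Dict, List


@dataclass
class ModelProviderWarning:
    required_config_keys: List[str]
    message: str


provider_warnings: Dict[str, ModelProviderWarning] = {
    "together_ai": ModelProviderWarning(
        required_config_keys=["together_api_key"],
        message="Attempted to use Together without an API key set.",
    ),
}

# hypothetical stand-in for Kiln's saved settings
settings = {"together_api_key": None}


def check_provider_config(provider: str) -> None:
    warning = provider_warnings.get(provider)
    if warning is None:
        return
    missing = [key for key in warning.required_config_keys if not settings.get(key)]
    if missing:
        raise ValueError(warning.message)


try:
    check_provider_config("together_ai")
except ValueError as err:
    print(err)  # prints the provider's warning message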
@@ -7,7 +7,7 @@ from pydantic import ValidationError

 from kiln_ai.adapters.adapter_registry import adapter_for_task
 from kiln_ai.adapters.model_adapters.base_adapter import RunOutput
-from kiln_ai.adapters.model_adapters.…
+from kiln_ai.adapters.model_adapters.litellm_adapter import LiteLlmAdapter
 from kiln_ai.adapters.repair.repair_task import (
     RepairTaskInput,
     RepairTaskRun,
@@ -217,7 +217,7 @@ async def test_mocked_repair_task_run(sample_task, sample_task_run, sample_repai
         "rating": 8,
     }

-    with patch.object(…
+    with patch.object(LiteLlmAdapter, "_run", new_callable=AsyncMock) as mock_run:
         mock_run.return_value = RunOutput(
             output=mocked_output, intermediate_outputs=None
         )
@@ -235,7 +235,7 @@ async def test_mocked_repair_task_run(sample_task, sample_task_run, sample_repai
     parsed_output = json.loads(run.output.output)
     assert parsed_output == mocked_output
     assert run.output.source.properties == {
-        "adapter_name": "…
+        "adapter_name": "kiln_openai_compatible_adapter",
         "model_name": "llama_3_1_8b",
         "model_provider": "ollama",
         "prompt_id": "simple_prompt_builder",
kiln_ai/adapters/run_output.py CHANGED
@@ -6,8 +6,7 @@ from kiln_ai import datamodel
 from kiln_ai.adapters.adapter_registry import adapter_for_task
 from kiln_ai.adapters.ml_model_list import ModelProviderName
 from kiln_ai.adapters.model_adapters.base_adapter import AdapterConfig
-from kiln_ai.adapters.model_adapters.…
-from kiln_ai.adapters.model_adapters.openai_model_adapter import OpenAICompatibleAdapter
+from kiln_ai.adapters.model_adapters.litellm_adapter import LiteLlmAdapter
 from kiln_ai.adapters.prompt_builders import BasePromptBuilder
 from kiln_ai.adapters.provider_tools import kiln_model_provider_from

@@ -44,9 +43,9 @@ def test_openai_adapter_creation(mock_config, basic_task):
         kiln_task=basic_task, model_name="gpt-4", provider=ModelProviderName.openai
     )

-    assert isinstance(adapter, …
+    assert isinstance(adapter, LiteLlmAdapter)
     assert adapter.config.model_name == "gpt-4"
-    assert adapter.config.…
+    assert adapter.config.additional_body_options == {"api_key": "test-openai-key"}
     assert adapter.config.provider_name == ModelProviderName.openai
     assert adapter.config.base_url is None  # OpenAI url is default
     assert adapter.config.default_headers is None
@@ -59,11 +58,10 @@ def test_openrouter_adapter_creation(mock_config, basic_task):
         provider=ModelProviderName.openrouter,
     )

-    assert isinstance(adapter, …
+    assert isinstance(adapter, LiteLlmAdapter)
     assert adapter.config.model_name == "anthropic/claude-3-opus"
-    assert adapter.config.…
+    assert adapter.config.additional_body_options == {"api_key": "test-openrouter-key"}
     assert adapter.config.provider_name == ModelProviderName.openrouter
-    assert adapter.config.base_url == "https://openrouter.ai/api/v1"
     assert adapter.config.default_headers == {
         "HTTP-Referer": "https://getkiln.ai/openrouter",
         "X-Title": "KilnAI",
@@ -79,12 +77,12 @@ def test_openrouter_adapter_creation(mock_config, basic_task):
         ModelProviderName.fireworks_ai,
     ],
 )
-def …
+def test_openai_compatible_adapter_creation(mock_config, basic_task, provider):
     adapter = adapter_for_task(
         kiln_task=basic_task, model_name="test-model", provider=provider
     )

-    assert isinstance(adapter, …
+    assert isinstance(adapter, LiteLlmAdapter)
     assert adapter.run_config.model_name == "test-model"


@@ -122,10 +120,12 @@ def test_invalid_provider(mock_config, basic_task):
         )


-@patch("kiln_ai.adapters.adapter_registry.…
+@patch("kiln_ai.adapters.adapter_registry.lite_llm_config")
 def test_openai_compatible_adapter(mock_compatible_config, mock_config, basic_task):
     mock_compatible_config.return_value.model_name = "test-model"
-    mock_compatible_config.return_value.…
+    mock_compatible_config.return_value.additional_body_options = {
+        "api_key": "test-key"
+    }
     mock_compatible_config.return_value.base_url = "https://test.com/v1"
     mock_compatible_config.return_value.provider_name = "CustomProvider99"

@@ -135,12 +135,9 @@ def test_openai_compatible_adapter(mock_compatible_config, mock_config, basic_ta
         provider=ModelProviderName.openai_compatible,
     )

-    assert isinstance(adapter, …
+    assert isinstance(adapter, LiteLlmAdapter)
     mock_compatible_config.assert_called_once_with("provider::test-model")
-    assert adapter.config…
-    assert adapter.config.api_key == "test-key"
-    assert adapter.config.base_url == "https://test.com/v1"
-    assert adapter.config.provider_name == "CustomProvider99"
+    assert adapter.config == mock_compatible_config.return_value


 def test_custom_openai_compatible_provider(mock_config, basic_task):
@@ -150,9 +147,9 @@ def test_custom_openai_compatible_provider(mock_config, basic_task):
         provider=ModelProviderName.kiln_custom_registry,
     )

-    assert isinstance(adapter, …
+    assert isinstance(adapter, LiteLlmAdapter)
     assert adapter.config.model_name == "openai::test-model"
-    assert adapter.config.…
+    assert adapter.config.additional_body_options == {"api_key": "test-openai-key"}
     assert adapter.config.base_url is None  # openai is none
     assert adapter.config.provider_name == ModelProviderName.kiln_custom_registry

@@ -165,7 +162,7 @@ async def test_fine_tune_provider(mock_config, basic_task, mock_finetune_from_id
     )

     mock_finetune_from_id.assert_called_once_with("proj::task::tune")
-    assert isinstance(adapter, …
+    assert isinstance(adapter, LiteLlmAdapter)
     assert adapter.config.provider_name == ModelProviderName.kiln_fine_tune
     # Kiln model name here, but the underlying openai model id below
     assert adapter.config.model_name == "proj::task::tune"
@@ -174,4 +171,4 @@
         "proj::task::tune", provider_name=ModelProviderName.kiln_fine_tune
     )
     # The actual model name from the fine tune object
-    assert provider.…
+    assert provider.model_id == "test-model"
@@ -58,8 +58,8 @@ def test_generate_model_table():
         table.append(row)

     # Print the table (useful for documentation)
-    …
-    …
+    print("\nModel Capability Matrix:\n")
+    print("\n".join(table))

     # Basic assertions to ensure the table is well-formed
     assert len(table) > 2, "Table should have header and at least one row"
@@ -1,13 +1,17 @@
 import os
 from pathlib import Path
+from unittest.mock import patch

 import pytest
-from …
+from litellm.utils import ModelResponse

 import kiln_ai.datamodel as datamodel
 from kiln_ai.adapters.adapter_registry import adapter_for_task
 from kiln_ai.adapters.ml_model_list import built_in_models
-from kiln_ai.adapters.model_adapters.…
+from kiln_ai.adapters.model_adapters.litellm_adapter import (
+    LiteLlmAdapter,
+    LiteLlmConfig,
+)
 from kiln_ai.adapters.ollama_tools import ollama_online
 from kiln_ai.adapters.prompt_builders import (
     BasePromptBuilder,
@@ -20,6 +24,9 @@ def get_all_models_and_providers():
     model_provider_pairs = []
     for model in built_in_models:
         for provider in model.providers:
+            if not provider.model_id:
+                # it's possible for models to not have an ID (fine-tune only model)
+                continue
             model_provider_pairs.append((model.name, provider.name))
     return model_provider_pairs

@@ -106,23 +113,27 @@ async def test_amazon_bedrock(tmp_path):
     await run_simple_test(tmp_path, "llama_3_1_8b", "amazon_bedrock")


-async def test_mock(tmp_path):
-    task = build_test_task(tmp_path)
-    mockChatModel = FakeListChatModel(responses=["mock response"])
-    adapter = LangchainAdapter(
-        task,
-        custom_model=mockChatModel,
-        provider="ollama",
-    )
-    run = await adapter.invoke("You are a mock, send me the response!")
-    assert "mock response" in run.output.output
-
-
 async def test_mock_returning_run(tmp_path):
     task = build_test_task(tmp_path)
-    …
-    …
-    …
+    with patch("litellm.acompletion") as mock_acompletion:
+        # Configure the mock to return a properly structured response
+        mock_acompletion.return_value = ModelResponse(
+            model="custom_model",
+            choices=[{"message": {"content": "mock response"}}],
+        )
+
+        adapter = LiteLlmAdapter(
+            config=LiteLlmConfig(
+                model_name="custom_model",
+                provider_name="ollama",
+                base_url="http://localhost:11434",
+                additional_body_options={"api_key": "test_key"},
+            ),
+            kiln_task=task,
+        )
+
+        run = await adapter.invoke("You are a mock, send me the response!")
+
     assert run.output.output == "mock response"
     assert run is not None
     assert run.id is not None
@@ -130,8 +141,8 @@ async def test_mock_returning_run(tmp_path):
     assert run.output.output == "mock response"
     assert "created_by" in run.input_source.properties
     assert run.output.source.properties == {
-        "adapter_name": "…
-        "model_name": "…
+        "adapter_name": "kiln_openai_compatible_adapter",
+        "model_name": "custom_model",
         "model_provider": "ollama",
         "prompt_id": "simple_prompt_builder",
     }
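The rewritten mock test above patches litellm.acompletion and returns a litellm ModelResponse, so no local model is needed. A reduced standalone sketch of the same pattern outside pytest, assuming only that the code under test awaits litellm.acompletion and reads choices[0].message.content (the ask helper is illustrative, not the adapter's actual implementation):

import asyncio
from unittest.mock import patch

import litellm
from litellm.utils import ModelResponse


async def ask(prompt: str) -> str:
    # stand-in for the adapter's internals: a single chat completion call
    response = await litellm.acompletion(
        model="ollama/custom_model",
        messages=[{"role": "user", "content": prompt}],
    )
    return response.choices[0].message.content


async def main() -> None:
    with patch("litellm.acompletion") as mock_acompletion:
        # same response shape as the mocked test above
        mock_acompletion.return_value = ModelResponse(
            model="custom_model",
            choices=[{"message": {"content": "mock response"}}],
        )
        print(await ask("You are a mock, send me the response!"))  # "mock response"


asyncio.run(main())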
@@ -17,12 +17,11 @@ from kiln_ai.adapters.provider_tools import (
     finetune_provider_model,
     get_model_and_provider,
     kiln_model_provider_from,
-    …
-    …
+    lite_llm_config,
+    lite_llm_provider_model,
     parse_custom_model_id,
     provider_enabled,
     provider_name_from_id,
-    provider_options_for_custom_model,
     provider_warnings,
 )
 from kiln_ai.datamodel import Finetune, StructuredOutputMode, Task
@@ -186,7 +185,7 @@ def test_get_model_and_provider_valid():
     assert provider is not None
     assert model.name == ModelName.phi_3_5
     assert provider.name == ModelProviderName.ollama
-    assert provider.…
+    assert provider.model_id == "phi3.5"


 def test_get_model_and_provider_invalid_model():
@@ -227,7 +226,7 @@ def test_get_model_and_provider_multiple_providers():
     assert provider is not None
     assert model.name == ModelName.llama_3_3_70b
     assert provider.name == ModelProviderName.groq
-    assert provider.…
+    assert provider.model_id == "llama-3.3-70b-versatile"


 @pytest.mark.asyncio
@@ -324,59 +323,7 @@ async def test_kiln_model_provider_from_custom_model_valid(mock_config):
     assert provider.supports_structured_output is False
     assert provider.supports_data_gen is False
     assert provider.untested_model is True
-    assert "…
-    assert provider.provider_options["model"] == "custom_model"
-
-
-def test_provider_options_for_custom_model_basic():
-    """Test basic case with custom model name"""
-    options = provider_options_for_custom_model(
-        "custom_model_name", ModelProviderName.openai
-    )
-    assert options == {"model": "custom_model_name"}
-
-
-def test_provider_options_for_custom_model_bedrock():
-    """Test Amazon Bedrock provider options"""
-    options = provider_options_for_custom_model(
-        ModelName.llama_3_1_8b, ModelProviderName.amazon_bedrock
-    )
-    assert options == {"model": ModelName.llama_3_1_8b, "region_name": "us-west-2"}
-
-
-@pytest.mark.parametrize(
-    "provider",
-    [
-        ModelProviderName.openai,
-        ModelProviderName.ollama,
-        ModelProviderName.fireworks_ai,
-        ModelProviderName.openrouter,
-        ModelProviderName.groq,
-    ],
-)
-def test_provider_options_for_custom_model_simple_providers(provider):
-    """Test providers that just need model name"""
-
-    options = provider_options_for_custom_model(ModelName.llama_3_1_8b, provider)
-    assert options == {"model": ModelName.llama_3_1_8b}
-
-
-def test_provider_options_for_custom_model_kiln_fine_tune():
-    """Test that kiln_fine_tune raises appropriate error"""
-    with pytest.raises(ValueError) as exc_info:
-        provider_options_for_custom_model(
-            "model_name", ModelProviderName.kiln_fine_tune
-        )
-    assert (
-        str(exc_info.value)
-        == "Fine tuned models should populate provider options via another path"
-    )
-
-
-def test_provider_options_for_custom_model_invalid_enum():
-    """Test handling of invalid enum value"""
-    with pytest.raises(ValueError):
-        provider_options_for_custom_model("model_name", "invalid_enum_value")
+    assert provider.model_id == "custom_model"


 @pytest.mark.asyncio
@@ -393,7 +340,7 @@ async def test_kiln_model_provider_from_custom_registry(mock_config):
     assert provider.supports_structured_output is False
     assert provider.supports_data_gen is False
     assert provider.untested_model is True
-    assert provider.…
+    assert provider.model_id == "gpt-4-turbo"


 @pytest.mark.asyncio
@@ -412,7 +359,7 @@ async def test_builtin_model_from_valid_model_default_provider(mock_config):

     assert provider is not None
     assert provider.name == ModelProviderName.ollama
-    assert provider.…
+    assert provider.model_id == "phi3.5"


 @pytest.mark.asyncio
@@ -426,7 +373,7 @@ async def test_builtin_model_from_valid_model_specific_provider(mock_config):

     assert provider is not None
     assert provider.name == ModelProviderName.groq
-    assert provider.…
+    assert provider.model_id == "llama-3.3-70b-versatile"


 @pytest.mark.asyncio
@@ -477,7 +424,7 @@ def test_finetune_provider_model_success(mock_project, mock_task, mock_finetune)
     provider = finetune_provider_model(model_id)

     assert provider.name == ModelProviderName.openai
-    assert provider.…
+    assert provider.model_id == "ft:gpt-3.5-turbo:custom:model-123"
     assert provider.structured_output_mode == StructuredOutputMode.json_schema


@@ -573,7 +520,7 @@ def test_finetune_provider_model_structured_mode(
     provider = finetune_provider_model("project-123::task-456::finetune-789")

     assert provider.name == provider_name
-    assert provider.…
+    assert provider.model_id == "fireworks-model-123"
     assert provider.structured_output_mode == expected_mode


@@ -581,69 +528,67 @@ def test_openai_compatible_provider_config(mock_shared_config):
     """Test successful creation of an OpenAI compatible provider"""
     model_id = "test_provider::gpt-4"

-    config = …
+    config = lite_llm_config(model_id)

     assert config.provider_name == ModelProviderName.openai_compatible
     assert config.model_name == "gpt-4"
-    assert config.…
+    assert config.additional_body_options == {"api_key": "test-key"}
     assert config.base_url == "https://api.test.com"


-def …
+def test_litellm_provider_model_success(mock_shared_config):
     """Test successful creation of an OpenAI compatible provider"""
     model_id = "test_provider::gpt-4"

-    provider = …
+    provider = lite_llm_provider_model(model_id)

     assert provider.name == ModelProviderName.openai_compatible
-    assert provider.…
-        "model": model_id,
-    }
+    assert provider.model_id == model_id
     assert provider.supports_structured_output is False
     assert provider.supports_data_gen is False
     assert provider.untested_model is True


-def …
+def test_lite_llm_config_no_api_key(mock_shared_config):
     """Test provider creation without API key (should work as some providers don't require it)"""
     model_id = "no_key_provider::gpt-4"

-    config = …
+    config = lite_llm_config(model_id)

     assert config.provider_name == ModelProviderName.openai_compatible
     assert config.model_name == "gpt-4"
-    assert config.api_key…
+    assert config.additional_body_options == {"api_key": None}
     assert config.base_url == "https://api.nokey.com"


-def …
+def test_lite_llm_config_invalid_id():
     """Test handling of invalid model ID format"""
     with pytest.raises(ValueError) as exc_info:
-        …
+        lite_llm_config("invalid-id-format")
     assert (
         str(exc_info.value) == "Invalid openai compatible model ID: invalid-id-format"
     )


-def …
+def test_lite_llm_config_no_providers(mock_shared_config):
     """Test handling when no providers are configured"""
     mock_shared_config.return_value.openai_compatible_providers = None

     with pytest.raises(ValueError) as exc_info:
-        …
+        lite_llm_config("test_provider::gpt-4")
     assert str(exc_info.value) == "OpenAI compatible provider test_provider not found"


-def …
+def test_lite_llm_config_provider_not_found(mock_shared_config):
     """Test handling of non-existent provider"""
     with pytest.raises(ValueError) as exc_info:
-        …
+        lite_llm_config("unknown_provider::gpt-4")
     assert (
         str(exc_info.value) == "OpenAI compatible provider unknown_provider not found"
     )


-def …
+def test_lite_llm_config_no_base_url(mock_shared_config):
     """Test handling of provider without base URL"""
     mock_shared_config.return_value.openai_compatible_providers = [
         {
@@ -653,7 +598,7 @@ def test_openai_compatible_config_no_base_url(mock_shared_config):
     ]

     with pytest.raises(ValueError) as exc_info:
-        …
+        lite_llm_config("test_provider::gpt-4")
     assert (
         str(exc_info.value)
         == "OpenAI compatible provider test_provider has no base URL"
kiln_ai/datamodel/basemodel.py CHANGED
@@ -268,6 +268,8 @@ class KilnParentedModel(KilnBaseModel, metaclass=ABCMeta):
         )
         if parent_path is None:
             return None
+        if not parent_path.exists():
+            return None
         loaded_parent = self.__class__.parent_type().load_from_file(parent_path)
         self.parent = loaded_parent
         return loaded_parent
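The two added lines guard the cached-parent lookup: if the recorded parent path no longer exists on disk, loading now returns None rather than failing when the file is read. The same defensive pattern in isolation, as a generic sketch unrelated to Kiln's actual loader:

import json
from pathlib import Path
from typing import Optional


def load_parent(parent_path: Optional[Path]) -> Optional[dict]:
    # mirrors the guard added above: a missing path or missing file both yield None
    if parent_path is None:
        return None
    if not parent_path.exists():
        return None
    return json.loads(parent_path.read_text())


print(load_parent(Path("missing_parent.kiln")))  # None instead of FileNotFoundError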
@@ -30,6 +30,7 @@ class StructuredOutputMode(str, Enum):
     - json_mode: request json using API's JSON mode, which should return valid JSON, but isn't checking/passing the schema
     - json_instructions: append instructions to the prompt to request json matching the schema. No API capabilities are used. You should have a custom parser on these models as they will be returning strings.
     - json_instruction_and_object: append instructions to the prompt to request json matching the schema. Also request the response as json_mode via API capabilities (returning dictionaries).
+    - json_custom_instructions: The model should output JSON, but custom instructions are already included in the system prompt. Don't append additional JSON instructions.
     """

     default = "default"
@@ -39,6 +40,7 @@ class StructuredOutputMode(str, Enum):
     json_mode = "json_mode"
     json_instructions = "json_instructions"
     json_instruction_and_object = "json_instruction_and_object"
+    json_custom_instructions = "json_custom_instructions"


 class FineTuneStatusType(str, Enum):
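The datamodel change adds a json_custom_instructions member to StructuredOutputMode, documented above as: the model should output JSON, but the instructions already live in the system prompt, so nothing extra should be appended. A small sketch of how calling code might branch on the new member; the decide_prompt_suffix helper is illustrative, not part of the package:

from enum import Enum


class StructuredOutputMode(str, Enum):
    # subset of the modes listed in the diff above
    json_mode = "json_mode"
    json_instructions = "json_instructions"
    json_instruction_and_object = "json_instruction_and_object"
    json_custom_instructions = "json_custom_instructions"


def decide_prompt_suffix(mode: StructuredOutputMode, schema: str) -> str:
    # illustrative only: append schema instructions unless the prompt already has them
    if mode in (
        StructuredOutputMode.json_instructions,
        StructuredOutputMode.json_instruction_and_object,
    ):
        return f"\nReturn JSON matching this schema:\n{schema}"
    if mode == StructuredOutputMode.json_custom_instructions:
        return ""  # custom instructions are already in the system prompt
    return ""


print(decide_prompt_suffix(StructuredOutputMode.json_custom_instructions, "{}"))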
|