kiln-ai 0.6.1__py3-none-any.whl → 0.7.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: the registry has flagged this version of kiln-ai for review (see the registry's advisory for details).

Files changed (44)
  1. kiln_ai/adapters/__init__.py +2 -0
  2. kiln_ai/adapters/adapter_registry.py +19 -0
  3. kiln_ai/adapters/data_gen/test_data_gen_task.py +29 -21
  4. kiln_ai/adapters/fine_tune/__init__.py +14 -0
  5. kiln_ai/adapters/fine_tune/base_finetune.py +186 -0
  6. kiln_ai/adapters/fine_tune/dataset_formatter.py +187 -0
  7. kiln_ai/adapters/fine_tune/finetune_registry.py +11 -0
  8. kiln_ai/adapters/fine_tune/fireworks_finetune.py +308 -0
  9. kiln_ai/adapters/fine_tune/openai_finetune.py +205 -0
  10. kiln_ai/adapters/fine_tune/test_base_finetune.py +290 -0
  11. kiln_ai/adapters/fine_tune/test_dataset_formatter.py +342 -0
  12. kiln_ai/adapters/fine_tune/test_fireworks_tinetune.py +455 -0
  13. kiln_ai/adapters/fine_tune/test_openai_finetune.py +503 -0
  14. kiln_ai/adapters/langchain_adapters.py +103 -13
  15. kiln_ai/adapters/ml_model_list.py +239 -303
  16. kiln_ai/adapters/ollama_tools.py +115 -0
  17. kiln_ai/adapters/provider_tools.py +308 -0
  18. kiln_ai/adapters/repair/repair_task.py +4 -2
  19. kiln_ai/adapters/repair/test_repair_task.py +6 -11
  20. kiln_ai/adapters/test_langchain_adapter.py +229 -18
  21. kiln_ai/adapters/test_ollama_tools.py +42 -0
  22. kiln_ai/adapters/test_prompt_adaptors.py +7 -5
  23. kiln_ai/adapters/test_provider_tools.py +531 -0
  24. kiln_ai/adapters/test_structured_output.py +22 -43
  25. kiln_ai/datamodel/__init__.py +287 -24
  26. kiln_ai/datamodel/basemodel.py +122 -38
  27. kiln_ai/datamodel/model_cache.py +116 -0
  28. kiln_ai/datamodel/registry.py +31 -0
  29. kiln_ai/datamodel/test_basemodel.py +167 -4
  30. kiln_ai/datamodel/test_dataset_split.py +234 -0
  31. kiln_ai/datamodel/test_example_models.py +12 -0
  32. kiln_ai/datamodel/test_model_cache.py +244 -0
  33. kiln_ai/datamodel/test_models.py +215 -1
  34. kiln_ai/datamodel/test_registry.py +96 -0
  35. kiln_ai/utils/config.py +14 -1
  36. kiln_ai/utils/name_generator.py +125 -0
  37. kiln_ai/utils/test_name_geneator.py +47 -0
  38. kiln_ai-0.7.1.dist-info/METADATA +237 -0
  39. kiln_ai-0.7.1.dist-info/RECORD +58 -0
  40. {kiln_ai-0.6.1.dist-info → kiln_ai-0.7.1.dist-info}/WHEEL +1 -1
  41. kiln_ai/adapters/test_ml_model_list.py +0 -181
  42. kiln_ai-0.6.1.dist-info/METADATA +0 -88
  43. kiln_ai-0.6.1.dist-info/RECORD +0 -37
  44. {kiln_ai-0.6.1.dist-info → kiln_ai-0.7.1.dist-info}/licenses/LICENSE.txt +0 -0
kiln_ai/adapters/ollama_tools.py (new file)
@@ -0,0 +1,115 @@
+import os
+from typing import Any, List
+
+import httpx
+import requests
+from pydantic import BaseModel, Field
+
+from kiln_ai.adapters.ml_model_list import ModelProviderName, built_in_models
+from kiln_ai.utils.config import Config
+
+
+def ollama_base_url() -> str:
+    """
+    Gets the base URL for Ollama API connections.
+
+    Returns:
+        The base URL to use for Ollama API calls, using environment variable if set
+        or falling back to localhost default
+    """
+    config_base_url = Config.shared().ollama_base_url
+    if config_base_url:
+        return config_base_url
+    return "http://localhost:11434"
+
+
+async def ollama_online() -> bool:
+    """
+    Checks if the Ollama service is available and responding.
+
+    Returns:
+        True if Ollama is available and responding, False otherwise
+    """
+    try:
+        httpx.get(ollama_base_url() + "/api/tags")
+    except httpx.RequestError:
+        return False
+    return True
+
+
+class OllamaConnection(BaseModel):
+    message: str
+    supported_models: List[str]
+    untested_models: List[str] = Field(default_factory=list)
+
+    def all_models(self) -> List[str]:
+        return self.supported_models + self.untested_models
+
+
+# Parse the Ollama /api/tags response
+def parse_ollama_tags(tags: Any) -> OllamaConnection | None:
+    # Build a list of models we support for Ollama from the built-in model list
+    supported_ollama_models = [
+        provider.provider_options["model"]
+        for model in built_in_models
+        for provider in model.providers
+        if provider.name == ModelProviderName.ollama
+    ]
+    # Append model_aliases to supported_ollama_models
+    supported_ollama_models.extend(
+        [
+            alias
+            for model in built_in_models
+            for provider in model.providers
+            for alias in provider.provider_options.get("model_aliases", [])
+        ]
+    )
+
+    if "models" in tags:
+        models = tags["models"]
+        if isinstance(models, list):
+            model_names = [model["model"] for model in models]
+            available_supported_models = []
+            untested_models = []
+            supported_models_latest_aliases = [
+                f"{m}:latest" for m in supported_ollama_models
+            ]
+            for model in model_names:
+                if (
+                    model in supported_ollama_models
+                    or model in supported_models_latest_aliases
+                ):
+                    available_supported_models.append(model)
+                else:
+                    untested_models.append(model)
+
+            if available_supported_models or untested_models:
+                return OllamaConnection(
+                    message="Ollama connected",
+                    supported_models=available_supported_models,
+                    untested_models=untested_models,
+                )
+
+    return OllamaConnection(
+        message="Ollama is running, but no supported models are installed. Install one or more supported model, like 'ollama pull phi3.5'.",
+        supported_models=[],
+        untested_models=[],
+    )
+
+
+async def get_ollama_connection() -> OllamaConnection | None:
+    """
+    Gets the connection status for Ollama.
+    """
+    try:
+        tags = requests.get(ollama_base_url() + "/api/tags", timeout=5).json()
+
+    except Exception:
+        return None
+
+    return parse_ollama_tags(tags)
+
+
+def ollama_model_installed(conn: OllamaConnection, model_name: str) -> bool:
+    all_models = conn.all_models()
+    return model_name in all_models or f"{model_name}:latest" in all_models
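
For orientation, a minimal sketch of how these new ollama_tools helpers could be exercised (not part of the diff; it assumes a local Ollama server, and the model name is illustrative):

# Sketch only: exercises the ollama_tools helpers added in 0.7.1.
# Assumes a reachable Ollama server; "phi3.5" is just an example model.
import asyncio

from kiln_ai.adapters.ollama_tools import (
    get_ollama_connection,
    ollama_model_installed,
)


async def main() -> None:
    conn = await get_ollama_connection()
    if conn is None:
        # None means the server was unreachable, as opposed to "no models installed".
        print("Ollama is not reachable")
        return
    print(conn.message)
    print("Supported:", conn.supported_models)
    print("Untested:", conn.untested_models)
    # Checks both "phi3.5" and "phi3.5:latest" against the installed models.
    print("phi3.5 installed:", ollama_model_installed(conn, "phi3.5"))


asyncio.run(main())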
kiln_ai/adapters/provider_tools.py (new file)
@@ -0,0 +1,308 @@
+from dataclasses import dataclass
+from typing import Dict, List, NoReturn
+
+from kiln_ai.adapters.ml_model_list import (
+    KilnModel,
+    KilnModelProvider,
+    ModelName,
+    ModelProviderName,
+    built_in_models,
+)
+from kiln_ai.adapters.ollama_tools import (
+    get_ollama_connection,
+)
+from kiln_ai.datamodel import Finetune, Task
+from kiln_ai.datamodel.registry import project_from_id
+
+from ..utils.config import Config
+
+
+async def provider_enabled(provider_name: ModelProviderName) -> bool:
+    if provider_name == ModelProviderName.ollama:
+        try:
+            conn = await get_ollama_connection()
+            return conn is not None and (
+                len(conn.supported_models) > 0 or len(conn.untested_models) > 0
+            )
+        except Exception:
+            return False
+
+    provider_warning = provider_warnings.get(provider_name)
+    if provider_warning is None:
+        return False
+    for required_key in provider_warning.required_config_keys:
+        if get_config_value(required_key) is None:
+            return False
+    return True
+
+
+def get_config_value(key: str):
+    try:
+        return Config.shared().__getattr__(key)
+    except AttributeError:
+        return None
+
+
+def check_provider_warnings(provider_name: ModelProviderName):
+    """
+    Validates that required configuration is present for a given provider.
+
+    Args:
+        provider_name: The provider to check
+
+    Raises:
+        ValueError: If required configuration keys are missing
+    """
+    warning_check = provider_warnings.get(provider_name)
+    if warning_check is None:
+        return
+    for key in warning_check.required_config_keys:
+        if get_config_value(key) is None:
+            raise ValueError(warning_check.message)
+
+
+async def builtin_model_from(
+    name: str, provider_name: str | None = None
+) -> KilnModelProvider | None:
+    """
+    Gets a model and provider from the built-in list of models.
+
+    Args:
+        name: The name of the model to get
+        provider_name: Optional specific provider to use (defaults to first available)
+
+    Returns:
+        A tuple of (provider, model)
+
+    Raises:
+        ValueError: If the model or provider is not found, or if the provider is misconfigured
+    """
+    if name not in ModelName.__members__:
+        return None
+
+    # Select the model from built_in_models using the name
+    model = next(filter(lambda m: m.name == name, built_in_models))
+    if model is None:
+        raise ValueError(f"Model {name} not found")
+
+    # If a provider is provided, select the provider from the model's provider_config
+    provider: KilnModelProvider | None = None
+    if model.providers is None or len(model.providers) == 0:
+        raise ValueError(f"Model {name} has no providers")
+    elif provider_name is None:
+        provider = model.providers[0]
+    else:
+        provider = next(
+            filter(lambda p: p.name == provider_name, model.providers), None
+        )
+        if provider is None:
+            return None
+
+    check_provider_warnings(provider.name)
+    return provider
+
+
+async def kiln_model_provider_from(
+    name: str, provider_name: str | None = None
+) -> KilnModelProvider:
+    if provider_name == ModelProviderName.kiln_fine_tune:
+        return finetune_provider_model(name)
+
+    built_in_model = await builtin_model_from(name, provider_name)
+    if built_in_model:
+        return built_in_model
+
+    # For custom registry, get the provider name and model name from the model id
+    if provider_name == ModelProviderName.kiln_custom_registry:
+        provider_name = name.split("::", 1)[0]
+        name = name.split("::", 1)[1]
+
+    # Custom/untested model. Set untested, and build a ModelProvider at runtime
+    if provider_name is None:
+        raise ValueError("Provider name is required for custom models")
+    if provider_name not in ModelProviderName.__members__:
+        raise ValueError(f"Invalid provider name: {provider_name}")
+    provider = ModelProviderName(provider_name)
+    check_provider_warnings(provider)
+    return KilnModelProvider(
+        name=provider,
+        supports_structured_output=False,
+        supports_data_gen=False,
+        untested_model=True,
+        provider_options=provider_options_for_custom_model(name, provider_name),
+    )
+
+
+finetune_cache: dict[str, KilnModelProvider] = {}
+
+
+def finetune_provider_model(
+    model_id: str,
+) -> KilnModelProvider:
+    if model_id in finetune_cache:
+        return finetune_cache[model_id]
+
+    try:
+        project_id, task_id, fine_tune_id = model_id.split("::")
+    except Exception:
+        raise ValueError(f"Invalid fine tune ID: {model_id}")
+    project = project_from_id(project_id)
+    if project is None:
+        raise ValueError(f"Project {project_id} not found")
+    task = Task.from_id_and_parent_path(task_id, project.path)
+    if task is None:
+        raise ValueError(f"Task {task_id} not found")
+    fine_tune = Finetune.from_id_and_parent_path(fine_tune_id, task.path)
+    if fine_tune is None:
+        raise ValueError(f"Fine tune {fine_tune_id} not found")
+    if fine_tune.fine_tune_model_id is None:
+        raise ValueError(
+            f"Fine tune {fine_tune_id} not completed. Refresh it's status in the fine-tune tab."
+        )
+
+    provider = ModelProviderName[fine_tune.provider]
+    model_provider = KilnModelProvider(
+        name=provider,
+        provider_options={
+            "model": fine_tune.fine_tune_model_id,
+        },
+    )
+
+    # TODO: Don't love this abstraction/logic.
+    if fine_tune.provider == ModelProviderName.fireworks_ai:
+        # Fireworks finetunes are trained with json, not tool calling (which is LC default format)
+        model_provider.adapter_options = {
+            "langchain": {
+                "with_structured_output_options": {
+                    "method": "json_mode",
+                }
+            }
+        }
+
+    finetune_cache[model_id] = model_provider
+    return model_provider
+
+
+def get_model_and_provider(
+    model_name: str, provider_name: str
+) -> tuple[KilnModel | None, KilnModelProvider | None]:
+    model = next(filter(lambda m: m.name == model_name, built_in_models), None)
+    if model is None:
+        return None, None
+    provider = next(filter(lambda p: p.name == provider_name, model.providers), None)
+    # all or nothing
+    if provider is None or model is None:
+        return None, None
+    return model, provider
+
+
+def provider_name_from_id(id: str) -> str:
+    """
+    Converts a provider ID to its human-readable name.
+
+    Args:
+        id: The provider identifier string
+
+    Returns:
+        The human-readable name of the provider
+
+    Raises:
+        ValueError: If the provider ID is invalid or unhandled
+    """
+    if id in ModelProviderName.__members__:
+        enum_id = ModelProviderName(id)
+        match enum_id:
+            case ModelProviderName.amazon_bedrock:
+                return "Amazon Bedrock"
+            case ModelProviderName.openrouter:
+                return "OpenRouter"
+            case ModelProviderName.groq:
+                return "Groq"
+            case ModelProviderName.ollama:
+                return "Ollama"
+            case ModelProviderName.openai:
+                return "OpenAI"
+            case ModelProviderName.kiln_fine_tune:
+                return "Fine Tuned Models"
+            case ModelProviderName.fireworks_ai:
+                return "Fireworks AI"
+            case ModelProviderName.kiln_custom_registry:
+                return "Custom Models"
+            case _:
+                # triggers pyright warning if I miss a case
+                raise_exhaustive_error(enum_id)
+
+    return "Unknown provider: " + id
+
+
+def provider_options_for_custom_model(
+    model_name: str, provider_name: str
+) -> Dict[str, str]:
+    """
+    Generated model provider options for a custom model. Each has their own format/options.
+    """
+
+    if provider_name not in ModelProviderName.__members__:
+        raise ValueError(f"Invalid provider name: {provider_name}")
+
+    enum_id = ModelProviderName(provider_name)
+    match enum_id:
+        case ModelProviderName.amazon_bedrock:
+            # us-west-2 is the only region consistently supported by Bedrock
+            return {"model": model_name, "region_name": "us-west-2"}
+        case (
+            ModelProviderName.openai
+            | ModelProviderName.ollama
+            | ModelProviderName.fireworks_ai
+            | ModelProviderName.openrouter
+            | ModelProviderName.groq
+        ):
+            return {"model": model_name}
+        case ModelProviderName.kiln_custom_registry:
+            raise ValueError(
+                "Custom models from registry should be parsed into provider/model before calling this."
+            )
+        case ModelProviderName.kiln_fine_tune:
+            raise ValueError(
+                "Fine tuned models should populate provider options via another path"
+            )
+        case _:
+            # triggers pyright warning if I miss a case
+            raise_exhaustive_error(enum_id)
+
+    # Won't reach this, type checking will catch missed values
+    return {"model": model_name}
+
+
+def raise_exhaustive_error(value: NoReturn) -> NoReturn:
+    raise ValueError(f"Unhandled enum value: {value}")
+
+
+@dataclass
+class ModelProviderWarning:
+    required_config_keys: List[str]
+    message: str
+
+
+provider_warnings: Dict[ModelProviderName, ModelProviderWarning] = {
+    ModelProviderName.amazon_bedrock: ModelProviderWarning(
+        required_config_keys=["bedrock_access_key", "bedrock_secret_key"],
+        message="Attempted to use Amazon Bedrock without an access key and secret set. \nGet your keys from https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/overview",
+    ),
+    ModelProviderName.openrouter: ModelProviderWarning(
+        required_config_keys=["open_router_api_key"],
+        message="Attempted to use OpenRouter without an API key set. \nGet your API key from https://openrouter.ai/settings/keys",
+    ),
+    ModelProviderName.groq: ModelProviderWarning(
+        required_config_keys=["groq_api_key"],
+        message="Attempted to use Groq without an API key set. \nGet your API key from https://console.groq.com/keys",
+    ),
+    ModelProviderName.openai: ModelProviderWarning(
+        required_config_keys=["open_ai_api_key"],
+        message="Attempted to use OpenAI without an API key set. \nGet your API key from https://platform.openai.com/account/api-keys",
+    ),
+    ModelProviderName.fireworks_ai: ModelProviderWarning(
+        required_config_keys=["fireworks_api_key", "fireworks_account_id"],
+        message="Attempted to use Fireworks without an API key and account ID set. \nGet your API key from https://fireworks.ai/account/api-keys and your account ID from https://fireworks.ai/account/profile",
+    ),
+}
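
For orientation, a sketch of how the new provider_tools helpers compose (not part of the diff; the model and provider names mirror the ones used in the tests, and results depend on your local Kiln config):

# Sketch only: the provider_tools helpers added in 0.7.1, exercised directly.
from kiln_ai.adapters.ml_model_list import ModelProviderName
from kiln_ai.adapters.provider_tools import (
    check_provider_warnings,
    get_model_and_provider,
    provider_name_from_id,
)

# Human-readable label for a provider id string, e.g. "Groq".
print(provider_name_from_id("groq"))

# Raises ValueError with a setup hint when groq_api_key is missing from config.
try:
    check_provider_warnings(ModelProviderName.groq)
except ValueError as e:
    print(e)

# Look up a built-in model/provider pair; returns (None, None) when not found.
model, provider = get_model_and_provider("llama_3_1_8b", "groq")
if provider is not None:
    print(provider.provider_options)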
kiln_ai/adapters/repair/repair_task.py
@@ -43,8 +43,10 @@ feedback describing what should be improved. Your job is to understand the evalu
     @classmethod
     def _original_prompt(cls, run: TaskRun, task: Task) -> str:
         prompt_builder_class: Type[BasePromptBuilder] | None = None
-        prompt_builder_name = run.output.source.properties.get(
-            "prompt_builder_name", None
+        prompt_builder_name = (
+            run.output.source.properties.get("prompt_builder_name", None)
+            if run.output.source
+            else None
         )
         if prompt_builder_name is not None and isinstance(prompt_builder_name, str):
             prompt_builder_class = prompt_builder_registry.get(
kiln_ai/adapters/repair/test_repair_task.py
@@ -5,10 +5,9 @@ from unittest.mock import AsyncMock, patch
 import pytest
 from pydantic import ValidationError
 
+from kiln_ai.adapters.adapter_registry import adapter_for_task
 from kiln_ai.adapters.base_adapter import RunOutput
-from kiln_ai.adapters.langchain_adapters import (
-    LangChainPromptAdapter,
-)
+from kiln_ai.adapters.langchain_adapters import LangchainAdapter
 from kiln_ai.adapters.repair.repair_task import (
     RepairTaskInput,
     RepairTaskRun,
@@ -60,7 +59,7 @@ json_joke_schema = """{
 
 @pytest.fixture
 def sample_task(tmp_path):
-    task_path = tmp_path / "task.json"
+    task_path = tmp_path / "task.kiln"
     task = Task(
         name="Joke Generator",
         path=task_path,
@@ -190,9 +189,7 @@ async def test_live_run(sample_task, sample_task_run, sample_repair_data):
     repair_task_input = RepairTaskRun.build_repair_task_input(**sample_repair_data)
     assert isinstance(repair_task_input, RepairTaskInput)
 
-    adapter = LangChainPromptAdapter(
-        repair_task, model_name="llama_3_1_8b", provider="groq"
-    )
+    adapter = adapter_for_task(repair_task, model_name="llama_3_1_8b", provider="groq")
 
     run = await adapter.invoke(repair_task_input.model_dump())
     assert run is not None
@@ -220,14 +217,12 @@ async def test_mocked_repair_task_run(sample_task, sample_task_run, sample_repai
         "rating": 8,
     }
 
-    with patch.object(
-        LangChainPromptAdapter, "_run", new_callable=AsyncMock
-    ) as mock_run:
+    with patch.object(LangchainAdapter, "_run", new_callable=AsyncMock) as mock_run:
         mock_run.return_value = RunOutput(
             output=mocked_output, intermediate_outputs=None
         )
 
-        adapter = LangChainPromptAdapter(
+        adapter = adapter_for_task(
             repair_task, model_name="llama_3_1_8b", provider="groq"
         )
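
The test changes above reflect the new adapter registry: adapters are now obtained via adapter_for_task rather than by constructing LangChainPromptAdapter directly. A minimal sketch of that call pattern (not part of the diff; the Task fields, plain-text input, and run.output access are assumptions based on the tests shown here):

# Sketch only: the adapter_for_task call shape used by the updated tests.
import asyncio

from kiln_ai.adapters.adapter_registry import adapter_for_task
from kiln_ai.datamodel import Task


async def main() -> None:
    # Hypothetical in-memory task; field names are assumed from the datamodel.
    task = Task(name="Joke Generator", instruction="Tell a short joke about the given topic.")
    # Same shape as the tests: adapter_for_task(task, model_name=..., provider=...)
    adapter = adapter_for_task(task, model_name="llama_3_1_8b", provider="groq")
    # Plain-text input here; tasks with an input schema pass a dict, as in the tests.
    run = await adapter.invoke("cats")
    print(run.output.output)


asyncio.run(main())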