kiln-ai 0.6.0__py3-none-any.whl → 0.7.0__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.

Potentially problematic release.


This version of kiln-ai might be problematic.

Files changed (42)
  1. kiln_ai/adapters/__init__.py +11 -1
  2. kiln_ai/adapters/adapter_registry.py +19 -0
  3. kiln_ai/adapters/data_gen/__init__.py +11 -0
  4. kiln_ai/adapters/data_gen/data_gen_task.py +69 -1
  5. kiln_ai/adapters/data_gen/test_data_gen_task.py +30 -21
  6. kiln_ai/adapters/fine_tune/__init__.py +14 -0
  7. kiln_ai/adapters/fine_tune/base_finetune.py +186 -0
  8. kiln_ai/adapters/fine_tune/dataset_formatter.py +187 -0
  9. kiln_ai/adapters/fine_tune/finetune_registry.py +11 -0
  10. kiln_ai/adapters/fine_tune/fireworks_finetune.py +308 -0
  11. kiln_ai/adapters/fine_tune/openai_finetune.py +205 -0
  12. kiln_ai/adapters/fine_tune/test_base_finetune.py +290 -0
  13. kiln_ai/adapters/fine_tune/test_dataset_formatter.py +342 -0
  14. kiln_ai/adapters/fine_tune/test_fireworks_tinetune.py +455 -0
  15. kiln_ai/adapters/fine_tune/test_openai_finetune.py +503 -0
  16. kiln_ai/adapters/langchain_adapters.py +103 -13
  17. kiln_ai/adapters/ml_model_list.py +218 -304
  18. kiln_ai/adapters/ollama_tools.py +114 -0
  19. kiln_ai/adapters/provider_tools.py +295 -0
  20. kiln_ai/adapters/repair/test_repair_task.py +6 -11
  21. kiln_ai/adapters/test_langchain_adapter.py +46 -18
  22. kiln_ai/adapters/test_ollama_tools.py +42 -0
  23. kiln_ai/adapters/test_prompt_adaptors.py +7 -5
  24. kiln_ai/adapters/test_provider_tools.py +312 -0
  25. kiln_ai/adapters/test_structured_output.py +22 -43
  26. kiln_ai/datamodel/__init__.py +235 -22
  27. kiln_ai/datamodel/basemodel.py +30 -0
  28. kiln_ai/datamodel/registry.py +31 -0
  29. kiln_ai/datamodel/test_basemodel.py +29 -1
  30. kiln_ai/datamodel/test_dataset_split.py +234 -0
  31. kiln_ai/datamodel/test_example_models.py +12 -0
  32. kiln_ai/datamodel/test_models.py +91 -1
  33. kiln_ai/datamodel/test_registry.py +96 -0
  34. kiln_ai/utils/config.py +9 -0
  35. kiln_ai/utils/name_generator.py +125 -0
  36. kiln_ai/utils/test_name_geneator.py +47 -0
  37. {kiln_ai-0.6.0.dist-info → kiln_ai-0.7.0.dist-info}/METADATA +4 -2
  38. kiln_ai-0.7.0.dist-info/RECORD +56 -0
  39. kiln_ai/adapters/test_ml_model_list.py +0 -181
  40. kiln_ai-0.6.0.dist-info/RECORD +0 -36
  41. {kiln_ai-0.6.0.dist-info → kiln_ai-0.7.0.dist-info}/WHEEL +0 -0
  42. {kiln_ai-0.6.0.dist-info → kiln_ai-0.7.0.dist-info}/licenses/LICENSE.txt +0 -0
kiln_ai/adapters/ollama_tools.py (new file)
@@ -0,0 +1,114 @@
+ import os
+ from typing import Any, List
+
+ import httpx
+ import requests
+ from pydantic import BaseModel, Field
+
+ from kiln_ai.adapters.ml_model_list import ModelProviderName, built_in_models
+
+
+ def ollama_base_url() -> str:
+     """
+     Gets the base URL for Ollama API connections.
+
+     Returns:
+         The base URL to use for Ollama API calls, using environment variable if set
+         or falling back to localhost default
+     """
+     env_base_url = os.getenv("OLLAMA_BASE_URL")
+     if env_base_url is not None:
+         return env_base_url
+     return "http://localhost:11434"
+
+
+ async def ollama_online() -> bool:
+     """
+     Checks if the Ollama service is available and responding.
+
+     Returns:
+         True if Ollama is available and responding, False otherwise
+     """
+     try:
+         httpx.get(ollama_base_url() + "/api/tags")
+     except httpx.RequestError:
+         return False
+     return True
+
+
+ class OllamaConnection(BaseModel):
+     message: str
+     supported_models: List[str]
+     untested_models: List[str] = Field(default_factory=list)
+
+     def all_models(self) -> List[str]:
+         return self.supported_models + self.untested_models
+
+
+ # Parse the Ollama /api/tags response
+ def parse_ollama_tags(tags: Any) -> OllamaConnection | None:
+     # Build a list of models we support for Ollama from the built-in model list
+     supported_ollama_models = [
+         provider.provider_options["model"]
+         for model in built_in_models
+         for provider in model.providers
+         if provider.name == ModelProviderName.ollama
+     ]
+     # Append model_aliases to supported_ollama_models
+     supported_ollama_models.extend(
+         [
+             alias
+             for model in built_in_models
+             for provider in model.providers
+             for alias in provider.provider_options.get("model_aliases", [])
+         ]
+     )
+
+     if "models" in tags:
+         models = tags["models"]
+         if isinstance(models, list):
+             model_names = [model["model"] for model in models]
+             available_supported_models = []
+             untested_models = []
+             supported_models_latest_aliases = [
+                 f"{m}:latest" for m in supported_ollama_models
+             ]
+             for model in model_names:
+                 if (
+                     model in supported_ollama_models
+                     or model in supported_models_latest_aliases
+                 ):
+                     available_supported_models.append(model)
+                 else:
+                     untested_models.append(model)
+
+             if available_supported_models or untested_models:
+                 return OllamaConnection(
+                     message="Ollama connected",
+                     supported_models=available_supported_models,
+                     untested_models=untested_models,
+                 )
+
+     return OllamaConnection(
+         message="Ollama is running, but no supported models are installed. Install one or more supported models, like 'ollama pull phi3.5'.",
+         supported_models=[],
+         untested_models=[],
+     )
+
+
+ async def get_ollama_connection() -> OllamaConnection | None:
+     """
+     Gets the connection status for Ollama.
+     """
+     try:
+         tags = requests.get(ollama_base_url() + "/api/tags", timeout=5).json()
+
+     except Exception:
+         return None
+
+     return parse_ollama_tags(tags)
+
+
+ def ollama_model_installed(conn: OllamaConnection, model_name: str) -> bool:
+     all_models = conn.all_models()
+     return model_name in all_models or f"{model_name}:latest" in all_models
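
The new ollama_tools module gives Kiln a small, self-contained surface for detecting a local Ollama server and the models it serves. A minimal usage sketch, not taken from the package (the model name "phi3.5" and the main() wrapper are illustrative only):

import asyncio

from kiln_ai.adapters.ollama_tools import get_ollama_connection, ollama_model_installed


async def main() -> None:
    # Queries the Ollama /api/tags endpoint, honoring OLLAMA_BASE_URL if set.
    conn = await get_ollama_connection()
    if conn is None:
        print("Ollama is not reachable")
        return
    print(conn.message)
    print("Supported:", conn.supported_models)
    print("Untested:", conn.untested_models)
    # Matches either "phi3.5" or its "phi3.5:latest" alias in the installed list.
    if ollama_model_installed(conn, "phi3.5"):
        print("phi3.5 is installed")


asyncio.run(main())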
kiln_ai/adapters/provider_tools.py (new file)
@@ -0,0 +1,295 @@
+ from dataclasses import dataclass
+ from typing import Dict, List, NoReturn
+
+ from kiln_ai.adapters.ml_model_list import (
+     KilnModel,
+     KilnModelProvider,
+     ModelName,
+     ModelProviderName,
+     built_in_models,
+ )
+ from kiln_ai.adapters.ollama_tools import (
+     get_ollama_connection,
+ )
+ from kiln_ai.datamodel.registry import project_from_id
+
+ from ..utils.config import Config
+
+
+ async def provider_enabled(provider_name: ModelProviderName) -> bool:
+     if provider_name == ModelProviderName.ollama:
+         try:
+             conn = await get_ollama_connection()
+             return conn is not None and (
+                 len(conn.supported_models) > 0 or len(conn.untested_models) > 0
+             )
+         except Exception:
+             return False
+
+     provider_warning = provider_warnings.get(provider_name)
+     if provider_warning is None:
+         return False
+     for required_key in provider_warning.required_config_keys:
+         if get_config_value(required_key) is None:
+             return False
+     return True
+
+
+ def get_config_value(key: str):
+     try:
+         return Config.shared().__getattr__(key)
+     except AttributeError:
+         return None
+
+
+ def check_provider_warnings(provider_name: ModelProviderName):
+     """
+     Validates that required configuration is present for a given provider.
+
+     Args:
+         provider_name: The provider to check
+
+     Raises:
+         ValueError: If required configuration keys are missing
+     """
+     warning_check = provider_warnings.get(provider_name)
+     if warning_check is None:
+         return
+     for key in warning_check.required_config_keys:
+         if get_config_value(key) is None:
+             raise ValueError(warning_check.message)
+
+
+ async def builtin_model_from(
+     name: str, provider_name: str | None = None
+ ) -> KilnModelProvider | None:
+     """
+     Gets a model and provider from the built-in list of models.
+
+     Args:
+         name: The name of the model to get
+         provider_name: Optional specific provider to use (defaults to first available)
+
+     Returns:
+         The KilnModelProvider for the model, or None if the model or provider is not found
+
+     Raises:
+         ValueError: If the model or provider is not found, or if the provider is misconfigured
+     """
+     if name not in ModelName.__members__:
+         return None
+
+     # Select the model from built_in_models using the name
+     model = next(filter(lambda m: m.name == name, built_in_models))
+     if model is None:
+         raise ValueError(f"Model {name} not found")
+
+     # If a provider is provided, select the provider from the model's provider_config
+     provider: KilnModelProvider | None = None
+     if model.providers is None or len(model.providers) == 0:
+         raise ValueError(f"Model {name} has no providers")
+     elif provider_name is None:
+         provider = model.providers[0]
+     else:
+         provider = next(
+             filter(lambda p: p.name == provider_name, model.providers), None
+         )
+         if provider is None:
+             return None
+
+     check_provider_warnings(provider.name)
+     return provider
+
+
+ async def kiln_model_provider_from(
+     name: str, provider_name: str | None = None
+ ) -> KilnModelProvider:
+     if provider_name == ModelProviderName.kiln_fine_tune:
+         return finetune_provider_model(name)
+
+     built_in_model = await builtin_model_from(name, provider_name)
+     if built_in_model:
+         return built_in_model
+
+     # Custom/untested model. Set untested, and build a ModelProvider at runtime
+     if provider_name is None:
+         raise ValueError("Provider name is required for custom models")
+     if provider_name not in ModelProviderName.__members__:
+         raise ValueError(f"Invalid provider name: {provider_name}")
+     provider = ModelProviderName(provider_name)
+     check_provider_warnings(provider)
+     return KilnModelProvider(
+         name=provider,
+         supports_structured_output=False,
+         supports_data_gen=False,
+         untested_model=True,
+         provider_options=provider_options_for_custom_model(name, provider_name),
+     )
+
+
+ finetune_cache: dict[str, KilnModelProvider] = {}
+
+
+ def finetune_provider_model(
+     model_id: str,
+ ) -> KilnModelProvider:
+     if model_id in finetune_cache:
+         return finetune_cache[model_id]
+
+     try:
+         project_id, task_id, fine_tune_id = model_id.split("::")
+     except Exception:
+         raise ValueError(f"Invalid fine tune ID: {model_id}")
+     project = project_from_id(project_id)
+     if project is None:
+         raise ValueError(f"Project {project_id} not found")
+     task = next((t for t in project.tasks() if t.id == task_id), None)
+     if task is None:
+         raise ValueError(f"Task {task_id} not found")
+     fine_tune = next((f for f in task.finetunes() if f.id == fine_tune_id), None)
+     if fine_tune is None:
+         raise ValueError(f"Fine tune {fine_tune_id} not found")
+     if fine_tune.fine_tune_model_id is None:
+         raise ValueError(
+             f"Fine tune {fine_tune_id} not completed. Refresh its status in the fine-tune tab."
+         )
+
+     provider = ModelProviderName[fine_tune.provider]
+     model_provider = KilnModelProvider(
+         name=provider,
+         provider_options={
+             "model": fine_tune.fine_tune_model_id,
+         },
+     )
+
+     # TODO: Don't love this abstraction/logic.
+     if fine_tune.provider == ModelProviderName.fireworks_ai:
+         # Fireworks finetunes are trained with json, not tool calling (which is LC default format)
+         model_provider.adapter_options = {
+             "langchain": {
+                 "with_structured_output_options": {
+                     "method": "json_mode",
+                 }
+             }
+         }
+
+     finetune_cache[model_id] = model_provider
+     return model_provider
+
+
+ def get_model_and_provider(
+     model_name: str, provider_name: str
+ ) -> tuple[KilnModel | None, KilnModelProvider | None]:
+     model = next(filter(lambda m: m.name == model_name, built_in_models), None)
+     if model is None:
+         return None, None
+     provider = next(filter(lambda p: p.name == provider_name, model.providers), None)
+     # all or nothing
+     if provider is None or model is None:
+         return None, None
+     return model, provider
+
+
+ def provider_name_from_id(id: str) -> str:
+     """
+     Converts a provider ID to its human-readable name.
+
+     Args:
+         id: The provider identifier string
+
+     Returns:
+         The human-readable name of the provider
+
+     Raises:
+         ValueError: If the provider ID is invalid or unhandled
+     """
+     if id in ModelProviderName.__members__:
+         enum_id = ModelProviderName(id)
+         match enum_id:
+             case ModelProviderName.amazon_bedrock:
+                 return "Amazon Bedrock"
+             case ModelProviderName.openrouter:
+                 return "OpenRouter"
+             case ModelProviderName.groq:
+                 return "Groq"
+             case ModelProviderName.ollama:
+                 return "Ollama"
+             case ModelProviderName.openai:
+                 return "OpenAI"
+             case ModelProviderName.kiln_fine_tune:
+                 return "Fine Tuned Models"
+             case ModelProviderName.fireworks_ai:
+                 return "Fireworks AI"
+             case _:
+                 # triggers pyright warning if I miss a case
+                 raise_exhaustive_error(enum_id)
+
+     return "Unknown provider: " + id
+
+
+ def provider_options_for_custom_model(
+     model_name: str, provider_name: str
+ ) -> Dict[str, str]:
+     """
+     Generated model provider options for a custom model. Each has their own format/options.
+     """
+     if provider_name not in ModelProviderName.__members__:
+         raise ValueError(f"Invalid provider name: {provider_name}")
+
+     enum_id = ModelProviderName(provider_name)
+     match enum_id:
+         case ModelProviderName.amazon_bedrock:
+             # us-west-2 is the only region consistently supported by Bedrock
+             return {"model": model_name, "region_name": "us-west-2"}
+         case (
+             ModelProviderName.openai
+             | ModelProviderName.ollama
+             | ModelProviderName.fireworks_ai
+             | ModelProviderName.openrouter
+             | ModelProviderName.groq
+         ):
+             return {"model": model_name}
+         case ModelProviderName.kiln_fine_tune:
+             raise ValueError(
+                 "Fine tuned models should populate provider options via another path"
+             )
+         case _:
+             # triggers pyright warning if I miss a case
+             raise_exhaustive_error(enum_id)
+
+     # Won't reach this, type checking will catch missed values
+     return {"model": model_name}
+
+
+ def raise_exhaustive_error(value: NoReturn) -> NoReturn:
+     raise ValueError(f"Unhandled enum value: {value}")
+
+
+ @dataclass
+ class ModelProviderWarning:
+     required_config_keys: List[str]
+     message: str
+
+
+ provider_warnings: Dict[ModelProviderName, ModelProviderWarning] = {
+     ModelProviderName.amazon_bedrock: ModelProviderWarning(
+         required_config_keys=["bedrock_access_key", "bedrock_secret_key"],
+         message="Attempted to use Amazon Bedrock without an access key and secret set. \nGet your keys from https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/overview",
+     ),
+     ModelProviderName.openrouter: ModelProviderWarning(
+         required_config_keys=["open_router_api_key"],
+         message="Attempted to use OpenRouter without an API key set. \nGet your API key from https://openrouter.ai/settings/keys",
+     ),
+     ModelProviderName.groq: ModelProviderWarning(
+         required_config_keys=["groq_api_key"],
+         message="Attempted to use Groq without an API key set. \nGet your API key from https://console.groq.com/keys",
+     ),
+     ModelProviderName.openai: ModelProviderWarning(
+         required_config_keys=["open_ai_api_key"],
+         message="Attempted to use OpenAI without an API key set. \nGet your API key from https://platform.openai.com/account/api-keys",
+     ),
+     ModelProviderName.fireworks_ai: ModelProviderWarning(
+         required_config_keys=["fireworks_api_key", "fireworks_account_id"],
+         message="Attempted to use Fireworks without an API key and account ID set. \nGet your API key from https://fireworks.ai/account/api-keys and your account ID from https://fireworks.ai/account/profile",
+     ),
+ }
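
provider_tools.py centralizes model/provider resolution and configuration checks. A hedged sketch of how the new helpers might be called, based only on the code above ("my-local-model" is an illustrative custom model name; "llama_3_1_8b" and "groq" mirror the tests later in this diff):

import asyncio

from kiln_ai.adapters.provider_tools import (
    get_model_and_provider,
    kiln_model_provider_from,
    provider_name_from_id,
)


async def main() -> None:
    # Human-readable label for a provider ID string.
    print(provider_name_from_id("fireworks_ai"))  # "Fireworks AI"

    # Look up a built-in model and one of its providers, with no config checks.
    model, provider = get_model_and_provider("llama_3_1_8b", "groq")
    if model and provider:
        print(model.name, provider.provider_options)

    # Full resolution path: built-in models, fine-tunes (IDs of the form
    # "project_id::task_id::fine_tune_id" via the kiln_fine_tune provider),
    # or an untested custom model. Raises ValueError when a provider's
    # required config keys (see provider_warnings) are missing.
    custom = await kiln_model_provider_from("my-local-model", "ollama")
    print(custom.untested_model, custom.provider_options)


asyncio.run(main())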
kiln_ai/adapters/repair/test_repair_task.py
@@ -5,10 +5,9 @@ from unittest.mock import AsyncMock, patch
  import pytest
  from pydantic import ValidationError

+ from kiln_ai.adapters.adapter_registry import adapter_for_task
  from kiln_ai.adapters.base_adapter import RunOutput
- from kiln_ai.adapters.langchain_adapters import (
-     LangChainPromptAdapter,
- )
+ from kiln_ai.adapters.langchain_adapters import LangchainAdapter
  from kiln_ai.adapters.repair.repair_task import (
      RepairTaskInput,
      RepairTaskRun,
@@ -60,7 +59,7 @@ json_joke_schema = """{

  @pytest.fixture
  def sample_task(tmp_path):
-     task_path = tmp_path / "task.json"
+     task_path = tmp_path / "task.kiln"
      task = Task(
          name="Joke Generator",
          path=task_path,
@@ -190,9 +189,7 @@ async def test_live_run(sample_task, sample_task_run, sample_repair_data):
      repair_task_input = RepairTaskRun.build_repair_task_input(**sample_repair_data)
      assert isinstance(repair_task_input, RepairTaskInput)

-     adapter = LangChainPromptAdapter(
-         repair_task, model_name="llama_3_1_8b", provider="groq"
-     )
+     adapter = adapter_for_task(repair_task, model_name="llama_3_1_8b", provider="groq")

      run = await adapter.invoke(repair_task_input.model_dump())
      assert run is not None
@@ -220,14 +217,12 @@ async def test_mocked_repair_task_run(sample_task, sample_task_run, sample_repai
          "rating": 8,
      }

-     with patch.object(
-         LangChainPromptAdapter, "_run", new_callable=AsyncMock
-     ) as mock_run:
+     with patch.object(LangchainAdapter, "_run", new_callable=AsyncMock) as mock_run:
          mock_run.return_value = RunOutput(
              output=mocked_output, intermediate_outputs=None
          )

-         adapter = LangChainPromptAdapter(
+         adapter = adapter_for_task(
              repair_task, model_name="llama_3_1_8b", provider="groq"
          )

kiln_ai/adapters/test_langchain_adapter.py
@@ -3,16 +3,17 @@ from unittest.mock import AsyncMock, MagicMock, patch
  from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
  from langchain_groq import ChatGroq

- from kiln_ai.adapters.langchain_adapters import LangChainPromptAdapter
+ from kiln_ai.adapters.langchain_adapters import (
+     LangchainAdapter,
+     get_structured_output_options,
+ )
  from kiln_ai.adapters.prompt_builders import SimpleChainOfThoughtPromptBuilder
  from kiln_ai.adapters.test_prompt_adaptors import build_test_task


  def test_langchain_adapter_munge_response(tmp_path):
      task = build_test_task(tmp_path)
-     lca = LangChainPromptAdapter(
-         kiln_task=task, model_name="llama_3_1_8b", provider="ollama"
-     )
+     lca = LangchainAdapter(kiln_task=task, model_name="llama_3_1_8b", provider="ollama")
      # Mistral Large tool calling format is a bit different
      response = {
          "name": "task_response",
@@ -35,7 +36,7 @@ def test_langchain_adapter_infer_model_name(tmp_path):
      task = build_test_task(tmp_path)
      custom = ChatGroq(model="llama-3.1-8b-instant", groq_api_key="test")

-     lca = LangChainPromptAdapter(kiln_task=task, custom_model=custom)
+     lca = LangchainAdapter(kiln_task=task, custom_model=custom)

      model_info = lca.adapter_info()
      assert model_info.model_name == "custom.langchain:llama-3.1-8b-instant"
@@ -45,9 +46,7 @@ def test_langchain_adapter_infer_model_name(tmp_path):
  def test_langchain_adapter_info(tmp_path):
      task = build_test_task(tmp_path)

-     lca = LangChainPromptAdapter(
-         kiln_task=task, model_name="llama_3_1_8b", provider="ollama"
-     )
+     lca = LangchainAdapter(kiln_task=task, model_name="llama_3_1_8b", provider="ollama")

      model_info = lca.adapter_info()
      assert model_info.adapter_name == "kiln_langchain_adapter"
@@ -60,7 +59,7 @@ async def test_langchain_adapter_with_cot(tmp_path):
      task.output_json_schema = (
          '{"type": "object", "properties": {"count": {"type": "integer"}}}'
      )
-     lca = LangChainPromptAdapter(
+     lca = LangchainAdapter(
          kiln_task=task,
          model_name="llama_3_1_8b",
          provider="ollama",
@@ -69,13 +68,13 @@ async def test_langchain_adapter_with_cot(tmp_path):

      # Mock the base model and its invoke method
      mock_base_model = MagicMock()
-     mock_base_model.invoke.return_value = AIMessage(
-         content="Chain of thought reasoning..."
+     mock_base_model.ainvoke = AsyncMock(
+         return_value=AIMessage(content="Chain of thought reasoning...")
      )

      # Create a separate mock for self.model()
      mock_model_instance = MagicMock()
-     mock_model_instance.invoke.return_value = {"parsed": {"count": 1}}
+     mock_model_instance.ainvoke = AsyncMock(return_value={"parsed": {"count": 1}})

      # Mock the langchain_model_from function to return the base model
      mock_model_from = AsyncMock(return_value=mock_base_model)
@@ -85,14 +84,14 @@ async def test_langchain_adapter_with_cot(tmp_path):
          patch(
              "kiln_ai.adapters.langchain_adapters.langchain_model_from", mock_model_from
          ),
-         patch.object(LangChainPromptAdapter, "model", return_value=mock_model_instance),
+         patch.object(LangchainAdapter, "model", return_value=mock_model_instance),
      ):
          response = await lca._run("test input")

      # First 3 messages are the same for both calls
      for invoke_args in [
-         mock_base_model.invoke.call_args[0][0],
-         mock_model_instance.invoke.call_args[0][0],
+         mock_base_model.ainvoke.call_args[0][0],
+         mock_model_instance.ainvoke.call_args[0][0],
      ]:
          assert isinstance(
              invoke_args[0], SystemMessage
@@ -107,11 +106,11 @@ async def test_langchain_adapter_with_cot(tmp_path):
          assert "step by step" in invoke_args[2].content

      # the COT should only have 3 messages
-     assert len(mock_base_model.invoke.call_args[0][0]) == 3
-     assert len(mock_model_instance.invoke.call_args[0][0]) == 5
+     assert len(mock_base_model.ainvoke.call_args[0][0]) == 3
+     assert len(mock_model_instance.ainvoke.call_args[0][0]) == 5

      # the final response should have the COT content and the final instructions
-     invoke_args = mock_model_instance.invoke.call_args[0][0]
+     invoke_args = mock_model_instance.ainvoke.call_args[0][0]
      assert isinstance(invoke_args[3], AIMessage)
      assert "Chain of thought reasoning..." in invoke_args[3].content
      assert isinstance(invoke_args[4], SystemMessage)
@@ -122,3 +121,32 @@ async def test_langchain_adapter_with_cot(tmp_path):
          == "Chain of thought reasoning..."
      )
      assert response.output == {"count": 1}
+
+
+ async def test_get_structured_output_options():
+     # Mock the provider response
+     mock_provider = MagicMock()
+     mock_provider.adapter_options = {
+         "langchain": {
+             "with_structured_output_options": {
+                 "force_json_response": True,
+                 "max_retries": 3,
+             }
+         }
+     }
+
+     # Test with provider that has options
+     with patch(
+         "kiln_ai.adapters.langchain_adapters.kiln_model_provider_from",
+         AsyncMock(return_value=mock_provider),
+     ):
+         options = await get_structured_output_options("model_name", "provider")
+         assert options == {"force_json_response": True, "max_retries": 3}
+
+     # Test with provider that has no options
+     with patch(
+         "kiln_ai.adapters.langchain_adapters.kiln_model_provider_from",
+         AsyncMock(return_value=None),
+     ):
+         options = await get_structured_output_options("model_name", "provider")
+         assert options == {}
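
The new test pins down the contract of get_structured_output_options: resolve the provider with kiln_model_provider_from, then return its langchain with_structured_output_options, or an empty dict when the provider or the options are absent. A minimal sketch consistent with that contract, inferred from the test rather than copied from the package (parameter names are assumptions):

from typing import Any

from kiln_ai.adapters.provider_tools import kiln_model_provider_from


async def get_structured_output_options(model_name: str, model_provider: str) -> dict[str, Any]:
    provider = await kiln_model_provider_from(model_name, model_provider)
    if provider is None or not provider.adapter_options:
        return {}
    # Example from provider_tools.py: Fireworks fine-tunes set
    # {"langchain": {"with_structured_output_options": {"method": "json_mode"}}}.
    return provider.adapter_options.get("langchain", {}).get(
        "with_structured_output_options", {}
    )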
kiln_ai/adapters/test_ollama_tools.py (new file)
@@ -0,0 +1,42 @@
+ import json
+
+ from kiln_ai.adapters.ollama_tools import (
+     OllamaConnection,
+     ollama_model_installed,
+     parse_ollama_tags,
+ )
+
+
+ def test_parse_ollama_tags_no_models():
+     json_response = '{"models":[{"name":"scosman_net","model":"scosman_net:latest"},{"name":"phi3.5:latest","model":"phi3.5:latest","modified_at":"2024-10-02T12:04:35.191519822-04:00","size":2176178843,"digest":"61819fb370a3c1a9be6694869331e5f85f867a079e9271d66cb223acb81d04ba","details":{"parent_model":"","format":"gguf","family":"phi3","families":["phi3"],"parameter_size":"3.8B","quantization_level":"Q4_0"}},{"name":"gemma2:2b","model":"gemma2:2b","modified_at":"2024-09-09T16:46:38.64348929-04:00","size":1629518495,"digest":"8ccf136fdd5298f3ffe2d69862750ea7fb56555fa4d5b18c04e3fa4d82ee09d7","details":{"parent_model":"","format":"gguf","family":"gemma2","families":["gemma2"],"parameter_size":"2.6B","quantization_level":"Q4_0"}},{"name":"llama3.1:latest","model":"llama3.1:latest","modified_at":"2024-09-01T17:19:43.481523695-04:00","size":4661230720,"digest":"f66fc8dc39ea206e03ff6764fcc696b1b4dfb693f0b6ef751731dd4e6269046e","details":{"parent_model":"","format":"gguf","family":"llama","families":["llama"],"parameter_size":"8.0B","quantization_level":"Q4_0"}}]}'
+     tags = json.loads(json_response)
+     print(json.dumps(tags, indent=2))
+     conn = parse_ollama_tags(tags)
+     assert "phi3.5:latest" in conn.supported_models
+     assert "gemma2:2b" in conn.supported_models
+     assert "llama3.1:latest" in conn.supported_models
+     assert "scosman_net:latest" in conn.untested_models
+
+
+ def test_parse_ollama_tags_only_untested_models():
+     json_response = '{"models":[{"name":"scosman_net","model":"scosman_net:latest"}]}'
+     tags = json.loads(json_response)
+     conn = parse_ollama_tags(tags)
+     assert conn.supported_models == []
+     assert conn.untested_models == ["scosman_net:latest"]
+
+
+ def test_ollama_model_installed():
+     conn = OllamaConnection(
+         supported_models=["phi3.5:latest", "gemma2:2b", "llama3.1:latest"],
+         message="Connected",
+         untested_models=["scosman_net:latest"],
+     )
+     assert ollama_model_installed(conn, "phi3.5:latest")
+     assert ollama_model_installed(conn, "phi3.5")
+     assert ollama_model_installed(conn, "gemma2:2b")
+     assert ollama_model_installed(conn, "llama3.1:latest")
+     assert ollama_model_installed(conn, "llama3.1")
+     assert ollama_model_installed(conn, "scosman_net:latest")
+     assert ollama_model_installed(conn, "scosman_net")
+     assert not ollama_model_installed(conn, "unknown_model")
kiln_ai/adapters/test_prompt_adaptors.py
@@ -5,8 +5,10 @@ import pytest
  from langchain_core.language_models.fake_chat_models import FakeListChatModel

  import kiln_ai.datamodel as datamodel
- from kiln_ai.adapters.langchain_adapters import LangChainPromptAdapter
- from kiln_ai.adapters.ml_model_list import built_in_models, ollama_online
+ from kiln_ai.adapters.adapter_registry import adapter_for_task
+ from kiln_ai.adapters.langchain_adapters import LangchainAdapter
+ from kiln_ai.adapters.ml_model_list import built_in_models
+ from kiln_ai.adapters.ollama_tools import ollama_online
  from kiln_ai.adapters.prompt_builders import (
      BasePromptBuilder,
      SimpleChainOfThoughtPromptBuilder,
@@ -106,7 +108,7 @@ async def test_amazon_bedrock(tmp_path):
  async def test_mock(tmp_path):
      task = build_test_task(tmp_path)
      mockChatModel = FakeListChatModel(responses=["mock response"])
-     adapter = LangChainPromptAdapter(task, custom_model=mockChatModel)
+     adapter = LangchainAdapter(task, custom_model=mockChatModel)
      run = await adapter.invoke("You are a mock, send me the response!")
      assert "mock response" in run.output.output

@@ -114,7 +116,7 @@ async def test_mock(tmp_path):
  async def test_mock_returning_run(tmp_path):
      task = build_test_task(tmp_path)
      mockChatModel = FakeListChatModel(responses=["mock response"])
-     adapter = LangChainPromptAdapter(task, custom_model=mockChatModel)
+     adapter = LangchainAdapter(task, custom_model=mockChatModel)
      run = await adapter.invoke("You are a mock, send me the response!")
      assert run.output.output == "mock response"
      assert run is not None
@@ -192,7 +194,7 @@ async def run_simple_task(
      provider: str,
      prompt_builder: BasePromptBuilder | None = None,
  ) -> datamodel.TaskRun:
-     adapter = LangChainPromptAdapter(
+     adapter = adapter_for_task(
          task, model_name=model_name, provider=provider, prompt_builder=prompt_builder
      )
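
Across the test diffs above, two API moves recur in 0.7.0: the adapter class rename (LangChainPromptAdapter → LangchainAdapter) and the new adapter_registry.adapter_for_task factory for callers that only need an adapter for a task. A hedged migration sketch (the task and input values are placeholders that mirror the tests):

import kiln_ai.datamodel as datamodel
from kiln_ai.adapters.adapter_registry import adapter_for_task


async def run_task(task: datamodel.Task, input_data: dict) -> datamodel.TaskRun:
    # 0.6.0: adapter = LangChainPromptAdapter(task, model_name=..., provider=...)
    # 0.7.0: let the registry choose the adapter implementation.
    adapter = adapter_for_task(task, model_name="llama_3_1_8b", provider="groq")
    return await adapter.invoke(input_data)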