kiln-ai 0.6.0__py3-none-any.whl → 0.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of kiln-ai has been flagged as potentially problematic.
- kiln_ai/adapters/__init__.py +11 -1
- kiln_ai/adapters/adapter_registry.py +19 -0
- kiln_ai/adapters/data_gen/__init__.py +11 -0
- kiln_ai/adapters/data_gen/data_gen_task.py +69 -1
- kiln_ai/adapters/data_gen/test_data_gen_task.py +30 -21
- kiln_ai/adapters/fine_tune/__init__.py +14 -0
- kiln_ai/adapters/fine_tune/base_finetune.py +186 -0
- kiln_ai/adapters/fine_tune/dataset_formatter.py +187 -0
- kiln_ai/adapters/fine_tune/finetune_registry.py +11 -0
- kiln_ai/adapters/fine_tune/fireworks_finetune.py +308 -0
- kiln_ai/adapters/fine_tune/openai_finetune.py +205 -0
- kiln_ai/adapters/fine_tune/test_base_finetune.py +290 -0
- kiln_ai/adapters/fine_tune/test_dataset_formatter.py +342 -0
- kiln_ai/adapters/fine_tune/test_fireworks_tinetune.py +455 -0
- kiln_ai/adapters/fine_tune/test_openai_finetune.py +503 -0
- kiln_ai/adapters/langchain_adapters.py +103 -13
- kiln_ai/adapters/ml_model_list.py +218 -304
- kiln_ai/adapters/ollama_tools.py +114 -0
- kiln_ai/adapters/provider_tools.py +295 -0
- kiln_ai/adapters/repair/test_repair_task.py +6 -11
- kiln_ai/adapters/test_langchain_adapter.py +46 -18
- kiln_ai/adapters/test_ollama_tools.py +42 -0
- kiln_ai/adapters/test_prompt_adaptors.py +7 -5
- kiln_ai/adapters/test_provider_tools.py +312 -0
- kiln_ai/adapters/test_structured_output.py +22 -43
- kiln_ai/datamodel/__init__.py +235 -22
- kiln_ai/datamodel/basemodel.py +30 -0
- kiln_ai/datamodel/registry.py +31 -0
- kiln_ai/datamodel/test_basemodel.py +29 -1
- kiln_ai/datamodel/test_dataset_split.py +234 -0
- kiln_ai/datamodel/test_example_models.py +12 -0
- kiln_ai/datamodel/test_models.py +91 -1
- kiln_ai/datamodel/test_registry.py +96 -0
- kiln_ai/utils/config.py +9 -0
- kiln_ai/utils/name_generator.py +125 -0
- kiln_ai/utils/test_name_geneator.py +47 -0
- {kiln_ai-0.6.0.dist-info → kiln_ai-0.7.0.dist-info}/METADATA +4 -2
- kiln_ai-0.7.0.dist-info/RECORD +56 -0
- kiln_ai/adapters/test_ml_model_list.py +0 -181
- kiln_ai-0.6.0.dist-info/RECORD +0 -36
- {kiln_ai-0.6.0.dist-info → kiln_ai-0.7.0.dist-info}/WHEEL +0 -0
- {kiln_ai-0.6.0.dist-info → kiln_ai-0.7.0.dist-info}/licenses/LICENSE.txt +0 -0
kiln_ai/adapters/ollama_tools.py
@@ -0,0 +1,114 @@
+import os
+from typing import Any, List
+
+import httpx
+import requests
+from pydantic import BaseModel, Field
+
+from kiln_ai.adapters.ml_model_list import ModelProviderName, built_in_models
+
+
+def ollama_base_url() -> str:
+    """
+    Gets the base URL for Ollama API connections.
+
+    Returns:
+        The base URL to use for Ollama API calls, using environment variable if set
+        or falling back to localhost default
+    """
+    env_base_url = os.getenv("OLLAMA_BASE_URL")
+    if env_base_url is not None:
+        return env_base_url
+    return "http://localhost:11434"
+
+
+async def ollama_online() -> bool:
+    """
+    Checks if the Ollama service is available and responding.
+
+    Returns:
+        True if Ollama is available and responding, False otherwise
+    """
+    try:
+        httpx.get(ollama_base_url() + "/api/tags")
+    except httpx.RequestError:
+        return False
+    return True
+
+
+class OllamaConnection(BaseModel):
+    message: str
+    supported_models: List[str]
+    untested_models: List[str] = Field(default_factory=list)
+
+    def all_models(self) -> List[str]:
+        return self.supported_models + self.untested_models
+
+
+# Parse the Ollama /api/tags response
+def parse_ollama_tags(tags: Any) -> OllamaConnection | None:
+    # Build a list of models we support for Ollama from the built-in model list
+    supported_ollama_models = [
+        provider.provider_options["model"]
+        for model in built_in_models
+        for provider in model.providers
+        if provider.name == ModelProviderName.ollama
+    ]
+    # Append model_aliases to supported_ollama_models
+    supported_ollama_models.extend(
+        [
+            alias
+            for model in built_in_models
+            for provider in model.providers
+            for alias in provider.provider_options.get("model_aliases", [])
+        ]
+    )
+
+    if "models" in tags:
+        models = tags["models"]
+        if isinstance(models, list):
+            model_names = [model["model"] for model in models]
+            available_supported_models = []
+            untested_models = []
+            supported_models_latest_aliases = [
+                f"{m}:latest" for m in supported_ollama_models
+            ]
+            for model in model_names:
+                if (
+                    model in supported_ollama_models
+                    or model in supported_models_latest_aliases
+                ):
+                    available_supported_models.append(model)
+                else:
+                    untested_models.append(model)
+
+            if available_supported_models or untested_models:
+                return OllamaConnection(
+                    message="Ollama connected",
+                    supported_models=available_supported_models,
+                    untested_models=untested_models,
+                )
+
+    return OllamaConnection(
+        message="Ollama is running, but no supported models are installed. Install one or more supported model, like 'ollama pull phi3.5'.",
+        supported_models=[],
+        untested_models=[],
+    )
+
+
+async def get_ollama_connection() -> OllamaConnection | None:
+    """
+    Gets the connection status for Ollama.
+    """
+    try:
+        tags = requests.get(ollama_base_url() + "/api/tags", timeout=5).json()
+
+    except Exception:
+        return None
+
+    return parse_ollama_tags(tags)
+
+
+def ollama_model_installed(conn: OllamaConnection, model_name: str) -> bool:
+    all_models = conn.all_models()
+    return model_name in all_models or f"{model_name}:latest" in all_models
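For orientation (not part of the diff): a minimal sketch of how the new ollama_tools helpers compose, assuming kiln-ai 0.7.0 is installed and an Ollama server is reachable at the default URL. The check_model wrapper below is illustrative only and not part of the package.

import asyncio

from kiln_ai.adapters.ollama_tools import get_ollama_connection, ollama_model_installed


async def check_model(model_name: str) -> bool:
    # get_ollama_connection() returns None if the /api/tags request fails
    conn = await get_ollama_connection()
    if conn is None:
        return False
    # Matches either "name" or "name:latest", as implemented above
    return ollama_model_installed(conn, model_name)


print(asyncio.run(check_model("phi3.5")))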
kiln_ai/adapters/provider_tools.py
@@ -0,0 +1,295 @@
+from dataclasses import dataclass
+from typing import Dict, List, NoReturn
+
+from kiln_ai.adapters.ml_model_list import (
+    KilnModel,
+    KilnModelProvider,
+    ModelName,
+    ModelProviderName,
+    built_in_models,
+)
+from kiln_ai.adapters.ollama_tools import (
+    get_ollama_connection,
+)
+from kiln_ai.datamodel.registry import project_from_id
+
+from ..utils.config import Config
+
+
+async def provider_enabled(provider_name: ModelProviderName) -> bool:
+    if provider_name == ModelProviderName.ollama:
+        try:
+            conn = await get_ollama_connection()
+            return conn is not None and (
+                len(conn.supported_models) > 0 or len(conn.untested_models) > 0
+            )
+        except Exception:
+            return False
+
+    provider_warning = provider_warnings.get(provider_name)
+    if provider_warning is None:
+        return False
+    for required_key in provider_warning.required_config_keys:
+        if get_config_value(required_key) is None:
+            return False
+    return True
+
+
+def get_config_value(key: str):
+    try:
+        return Config.shared().__getattr__(key)
+    except AttributeError:
+        return None
+
+
+def check_provider_warnings(provider_name: ModelProviderName):
+    """
+    Validates that required configuration is present for a given provider.
+
+    Args:
+        provider_name: The provider to check
+
+    Raises:
+        ValueError: If required configuration keys are missing
+    """
+    warning_check = provider_warnings.get(provider_name)
+    if warning_check is None:
+        return
+    for key in warning_check.required_config_keys:
+        if get_config_value(key) is None:
+            raise ValueError(warning_check.message)
+
+
+async def builtin_model_from(
+    name: str, provider_name: str | None = None
+) -> KilnModelProvider | None:
+    """
+    Gets a model and provider from the built-in list of models.
+
+    Args:
+        name: The name of the model to get
+        provider_name: Optional specific provider to use (defaults to first available)
+
+    Returns:
+        A tuple of (provider, model)
+
+    Raises:
+        ValueError: If the model or provider is not found, or if the provider is misconfigured
+    """
+    if name not in ModelName.__members__:
+        return None
+
+    # Select the model from built_in_models using the name
+    model = next(filter(lambda m: m.name == name, built_in_models))
+    if model is None:
+        raise ValueError(f"Model {name} not found")
+
+    # If a provider is provided, select the provider from the model's provider_config
+    provider: KilnModelProvider | None = None
+    if model.providers is None or len(model.providers) == 0:
+        raise ValueError(f"Model {name} has no providers")
+    elif provider_name is None:
+        provider = model.providers[0]
+    else:
+        provider = next(
+            filter(lambda p: p.name == provider_name, model.providers), None
+        )
+        if provider is None:
+            return None
+
+    check_provider_warnings(provider.name)
+    return provider
+
+
+async def kiln_model_provider_from(
+    name: str, provider_name: str | None = None
+) -> KilnModelProvider:
+    if provider_name == ModelProviderName.kiln_fine_tune:
+        return finetune_provider_model(name)
+
+    built_in_model = await builtin_model_from(name, provider_name)
+    if built_in_model:
+        return built_in_model
+
+    # Custom/untested model. Set untested, and build a ModelProvider at runtime
+    if provider_name is None:
+        raise ValueError("Provider name is required for custom models")
+    if provider_name not in ModelProviderName.__members__:
+        raise ValueError(f"Invalid provider name: {provider_name}")
+    provider = ModelProviderName(provider_name)
+    check_provider_warnings(provider)
+    return KilnModelProvider(
+        name=provider,
+        supports_structured_output=False,
+        supports_data_gen=False,
+        untested_model=True,
+        provider_options=provider_options_for_custom_model(name, provider_name),
+    )
+
+
+finetune_cache: dict[str, KilnModelProvider] = {}
+
+
+def finetune_provider_model(
+    model_id: str,
+) -> KilnModelProvider:
+    if model_id in finetune_cache:
+        return finetune_cache[model_id]
+
+    try:
+        project_id, task_id, fine_tune_id = model_id.split("::")
+    except Exception:
+        raise ValueError(f"Invalid fine tune ID: {model_id}")
+    project = project_from_id(project_id)
+    if project is None:
+        raise ValueError(f"Project {project_id} not found")
+    task = next((t for t in project.tasks() if t.id == task_id), None)
+    if task is None:
+        raise ValueError(f"Task {task_id} not found")
+    fine_tune = next((f for f in task.finetunes() if f.id == fine_tune_id), None)
+    if fine_tune is None:
+        raise ValueError(f"Fine tune {fine_tune_id} not found")
+    if fine_tune.fine_tune_model_id is None:
+        raise ValueError(
+            f"Fine tune {fine_tune_id} not completed. Refresh it's status in the fine-tune tab."
+        )
+
+    provider = ModelProviderName[fine_tune.provider]
+    model_provider = KilnModelProvider(
+        name=provider,
+        provider_options={
+            "model": fine_tune.fine_tune_model_id,
+        },
+    )
+
+    # TODO: Don't love this abstraction/logic.
+    if fine_tune.provider == ModelProviderName.fireworks_ai:
+        # Fireworks finetunes are trained with json, not tool calling (which is LC default format)
+        model_provider.adapter_options = {
+            "langchain": {
+                "with_structured_output_options": {
+                    "method": "json_mode",
+                }
+            }
+        }
+
+    finetune_cache[model_id] = model_provider
+    return model_provider
+
+
+def get_model_and_provider(
+    model_name: str, provider_name: str
+) -> tuple[KilnModel | None, KilnModelProvider | None]:
+    model = next(filter(lambda m: m.name == model_name, built_in_models), None)
+    if model is None:
+        return None, None
+    provider = next(filter(lambda p: p.name == provider_name, model.providers), None)
+    # all or nothing
+    if provider is None or model is None:
+        return None, None
+    return model, provider
+
+
+def provider_name_from_id(id: str) -> str:
+    """
+    Converts a provider ID to its human-readable name.
+
+    Args:
+        id: The provider identifier string
+
+    Returns:
+        The human-readable name of the provider
+
+    Raises:
+        ValueError: If the provider ID is invalid or unhandled
+    """
+    if id in ModelProviderName.__members__:
+        enum_id = ModelProviderName(id)
+        match enum_id:
+            case ModelProviderName.amazon_bedrock:
+                return "Amazon Bedrock"
+            case ModelProviderName.openrouter:
+                return "OpenRouter"
+            case ModelProviderName.groq:
+                return "Groq"
+            case ModelProviderName.ollama:
+                return "Ollama"
+            case ModelProviderName.openai:
+                return "OpenAI"
+            case ModelProviderName.kiln_fine_tune:
+                return "Fine Tuned Models"
+            case ModelProviderName.fireworks_ai:
+                return "Fireworks AI"
+            case _:
+                # triggers pyright warning if I miss a case
+                raise_exhaustive_error(enum_id)
+
+    return "Unknown provider: " + id
+
+
+def provider_options_for_custom_model(
+    model_name: str, provider_name: str
+) -> Dict[str, str]:
+    """
+    Generated model provider options for a custom model. Each has their own format/options.
+    """
+    if provider_name not in ModelProviderName.__members__:
+        raise ValueError(f"Invalid provider name: {provider_name}")
+
+    enum_id = ModelProviderName(provider_name)
+    match enum_id:
+        case ModelProviderName.amazon_bedrock:
+            # us-west-2 is the only region consistently supported by Bedrock
+            return {"model": model_name, "region_name": "us-west-2"}
+        case (
+            ModelProviderName.openai
+            | ModelProviderName.ollama
+            | ModelProviderName.fireworks_ai
+            | ModelProviderName.openrouter
+            | ModelProviderName.groq
+        ):
+            return {"model": model_name}
+        case ModelProviderName.kiln_fine_tune:
+            raise ValueError(
+                "Fine tuned models should populate provider options via another path"
+            )
+        case _:
+            # triggers pyright warning if I miss a case
+            raise_exhaustive_error(enum_id)
+
+    # Won't reach this, type checking will catch missed values
+    return {"model": model_name}
+
+
+def raise_exhaustive_error(value: NoReturn) -> NoReturn:
+    raise ValueError(f"Unhandled enum value: {value}")
+
+
+@dataclass
+class ModelProviderWarning:
+    required_config_keys: List[str]
+    message: str
+
+
+provider_warnings: Dict[ModelProviderName, ModelProviderWarning] = {
+    ModelProviderName.amazon_bedrock: ModelProviderWarning(
+        required_config_keys=["bedrock_access_key", "bedrock_secret_key"],
+        message="Attempted to use Amazon Bedrock without an access key and secret set. \nGet your keys from https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/overview",
+    ),
+    ModelProviderName.openrouter: ModelProviderWarning(
+        required_config_keys=["open_router_api_key"],
+        message="Attempted to use OpenRouter without an API key set. \nGet your API key from https://openrouter.ai/settings/keys",
+    ),
+    ModelProviderName.groq: ModelProviderWarning(
+        required_config_keys=["groq_api_key"],
+        message="Attempted to use Groq without an API key set. \nGet your API key from https://console.groq.com/keys",
+    ),
+    ModelProviderName.openai: ModelProviderWarning(
+        required_config_keys=["open_ai_api_key"],
+        message="Attempted to use OpenAI without an API key set. \nGet your API key from https://platform.openai.com/account/api-keys",
+    ),
+    ModelProviderName.fireworks_ai: ModelProviderWarning(
+        required_config_keys=["fireworks_api_key", "fireworks_account_id"],
+        message="Attempted to use Fireworks without an API key and account ID set. \nGet your API key from https://fireworks.ai/account/api-keys and your account ID from https://fireworks.ai/account/profile",
+    ),
+}
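Also for orientation (not part of the diff): fine-tuned models are addressed by a composite ID of the form "project_id::task_id::fine_tune_id", which finetune_provider_model splits and resolves against the local project registry. Below is a small sketch of the lookup helpers, assuming kiln-ai 0.7.0 is installed; the model and provider names are taken from the tests later in this diff and are illustrative only.

from kiln_ai.adapters.provider_tools import (
    get_model_and_provider,
    provider_name_from_id,
)

# Provider IDs map to display names, e.g. "openai" -> "OpenAI"
print(provider_name_from_id("openai"))

# Resolve a built-in model/provider pair; returns (None, None) when either is unknown
model, provider = get_model_and_provider("llama_3_1_8b", "groq")
if model and provider:
    print(provider.provider_options)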
kiln_ai/adapters/repair/test_repair_task.py
@@ -5,10 +5,9 @@ from unittest.mock import AsyncMock, patch
 import pytest
 from pydantic import ValidationError
 
+from kiln_ai.adapters.adapter_registry import adapter_for_task
 from kiln_ai.adapters.base_adapter import RunOutput
-from kiln_ai.adapters.langchain_adapters import
-    LangChainPromptAdapter,
-)
+from kiln_ai.adapters.langchain_adapters import LangchainAdapter
 from kiln_ai.adapters.repair.repair_task import (
     RepairTaskInput,
     RepairTaskRun,
@@ -60,7 +59,7 @@ json_joke_schema = """{
 
 @pytest.fixture
 def sample_task(tmp_path):
-    task_path = tmp_path / "task.
+    task_path = tmp_path / "task.kiln"
     task = Task(
         name="Joke Generator",
         path=task_path,
@@ -190,9 +189,7 @@ async def test_live_run(sample_task, sample_task_run, sample_repair_data):
     repair_task_input = RepairTaskRun.build_repair_task_input(**sample_repair_data)
     assert isinstance(repair_task_input, RepairTaskInput)
 
-    adapter =
-        repair_task, model_name="llama_3_1_8b", provider="groq"
-    )
+    adapter = adapter_for_task(repair_task, model_name="llama_3_1_8b", provider="groq")
 
     run = await adapter.invoke(repair_task_input.model_dump())
     assert run is not None
@@ -220,14 +217,12 @@ async def test_mocked_repair_task_run(sample_task, sample_task_run, sample_repai
         "rating": 8,
     }
 
-    with patch.object(
-        LangChainPromptAdapter, "_run", new_callable=AsyncMock
-    ) as mock_run:
+    with patch.object(LangchainAdapter, "_run", new_callable=AsyncMock) as mock_run:
         mock_run.return_value = RunOutput(
            output=mocked_output, intermediate_outputs=None
        )
 
-        adapter =
+        adapter = adapter_for_task(
            repair_task, model_name="llama_3_1_8b", provider="groq"
        )
 
kiln_ai/adapters/test_langchain_adapter.py
@@ -3,16 +3,17 @@ from unittest.mock import AsyncMock, MagicMock, patch
 from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
 from langchain_groq import ChatGroq
 
-from kiln_ai.adapters.langchain_adapters import
+from kiln_ai.adapters.langchain_adapters import (
+    LangchainAdapter,
+    get_structured_output_options,
+)
 from kiln_ai.adapters.prompt_builders import SimpleChainOfThoughtPromptBuilder
 from kiln_ai.adapters.test_prompt_adaptors import build_test_task
 
 
 def test_langchain_adapter_munge_response(tmp_path):
     task = build_test_task(tmp_path)
-    lca =
-        kiln_task=task, model_name="llama_3_1_8b", provider="ollama"
-    )
+    lca = LangchainAdapter(kiln_task=task, model_name="llama_3_1_8b", provider="ollama")
     # Mistral Large tool calling format is a bit different
     response = {
         "name": "task_response",
@@ -35,7 +36,7 @@ def test_langchain_adapter_infer_model_name(tmp_path):
     task = build_test_task(tmp_path)
     custom = ChatGroq(model="llama-3.1-8b-instant", groq_api_key="test")
 
-    lca =
+    lca = LangchainAdapter(kiln_task=task, custom_model=custom)
 
     model_info = lca.adapter_info()
     assert model_info.model_name == "custom.langchain:llama-3.1-8b-instant"
@@ -45,9 +46,7 @@
 def test_langchain_adapter_info(tmp_path):
     task = build_test_task(tmp_path)
 
-    lca =
-        kiln_task=task, model_name="llama_3_1_8b", provider="ollama"
-    )
+    lca = LangchainAdapter(kiln_task=task, model_name="llama_3_1_8b", provider="ollama")
 
     model_info = lca.adapter_info()
     assert model_info.adapter_name == "kiln_langchain_adapter"
@@ -60,7 +59,7 @@ async def test_langchain_adapter_with_cot(tmp_path):
     task.output_json_schema = (
         '{"type": "object", "properties": {"count": {"type": "integer"}}}'
     )
-    lca =
+    lca = LangchainAdapter(
         kiln_task=task,
         model_name="llama_3_1_8b",
         provider="ollama",
@@ -69,13 +68,13 @@
 
     # Mock the base model and its invoke method
     mock_base_model = MagicMock()
-    mock_base_model.
-        content="Chain of thought reasoning..."
+    mock_base_model.ainvoke = AsyncMock(
+        return_value=AIMessage(content="Chain of thought reasoning...")
     )
 
     # Create a separate mock for self.model()
     mock_model_instance = MagicMock()
-    mock_model_instance.
+    mock_model_instance.ainvoke = AsyncMock(return_value={"parsed": {"count": 1}})
 
     # Mock the langchain_model_from function to return the base model
     mock_model_from = AsyncMock(return_value=mock_base_model)
@@ -85,14 +84,14 @@
         patch(
            "kiln_ai.adapters.langchain_adapters.langchain_model_from", mock_model_from
        ),
-        patch.object(
+        patch.object(LangchainAdapter, "model", return_value=mock_model_instance),
     ):
        response = await lca._run("test input")
 
        # First 3 messages are the same for both calls
        for invoke_args in [
-            mock_base_model.
-            mock_model_instance.
+            mock_base_model.ainvoke.call_args[0][0],
+            mock_model_instance.ainvoke.call_args[0][0],
        ]:
            assert isinstance(
                invoke_args[0], SystemMessage
@@ -107,11 +106,11 @@
            assert "step by step" in invoke_args[2].content
 
        # the COT should only have 3 messages
-        assert len(mock_base_model.
-        assert len(mock_model_instance.
+        assert len(mock_base_model.ainvoke.call_args[0][0]) == 3
+        assert len(mock_model_instance.ainvoke.call_args[0][0]) == 5
 
        # the final response should have the COT content and the final instructions
-        invoke_args = mock_model_instance.
+        invoke_args = mock_model_instance.ainvoke.call_args[0][0]
        assert isinstance(invoke_args[3], AIMessage)
        assert "Chain of thought reasoning..." in invoke_args[3].content
        assert isinstance(invoke_args[4], SystemMessage)
@@ -122,3 +121,32 @@
            == "Chain of thought reasoning..."
        )
        assert response.output == {"count": 1}
+
+
+async def test_get_structured_output_options():
+    # Mock the provider response
+    mock_provider = MagicMock()
+    mock_provider.adapter_options = {
+        "langchain": {
+            "with_structured_output_options": {
+                "force_json_response": True,
+                "max_retries": 3,
+            }
+        }
+    }
+
+    # Test with provider that has options
+    with patch(
+        "kiln_ai.adapters.langchain_adapters.kiln_model_provider_from",
+        AsyncMock(return_value=mock_provider),
+    ):
+        options = await get_structured_output_options("model_name", "provider")
+        assert options == {"force_json_response": True, "max_retries": 3}
+
+    # Test with provider that has no options
+    with patch(
+        "kiln_ai.adapters.langchain_adapters.kiln_model_provider_from",
+        AsyncMock(return_value=None),
+    ):
+        options = await get_structured_output_options("model_name", "provider")
+        assert options == {}
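Aside (not part of the diff): the nested adapter_options lookup that test_get_structured_output_options exercises can be read roughly as the sketch below. This is an illustration inferred from the assertions above, not the actual helper in kiln_ai.adapters.langchain_adapters, which also resolves the provider asynchronously via kiln_model_provider_from.

from typing import Any, Dict


def structured_output_options_sketch(provider: Any) -> Dict[str, Any]:
    # Returns {} when no provider was resolved or any key along the path is missing
    if provider is None:
        return {}
    return (
        getattr(provider, "adapter_options", {})
        .get("langchain", {})
        .get("with_structured_output_options", {})
    )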
kiln_ai/adapters/test_ollama_tools.py
@@ -0,0 +1,42 @@
+import json
+
+from kiln_ai.adapters.ollama_tools import (
+    OllamaConnection,
+    ollama_model_installed,
+    parse_ollama_tags,
+)
+
+
+def test_parse_ollama_tags_no_models():
+    json_response = '{"models":[{"name":"scosman_net","model":"scosman_net:latest"},{"name":"phi3.5:latest","model":"phi3.5:latest","modified_at":"2024-10-02T12:04:35.191519822-04:00","size":2176178843,"digest":"61819fb370a3c1a9be6694869331e5f85f867a079e9271d66cb223acb81d04ba","details":{"parent_model":"","format":"gguf","family":"phi3","families":["phi3"],"parameter_size":"3.8B","quantization_level":"Q4_0"}},{"name":"gemma2:2b","model":"gemma2:2b","modified_at":"2024-09-09T16:46:38.64348929-04:00","size":1629518495,"digest":"8ccf136fdd5298f3ffe2d69862750ea7fb56555fa4d5b18c04e3fa4d82ee09d7","details":{"parent_model":"","format":"gguf","family":"gemma2","families":["gemma2"],"parameter_size":"2.6B","quantization_level":"Q4_0"}},{"name":"llama3.1:latest","model":"llama3.1:latest","modified_at":"2024-09-01T17:19:43.481523695-04:00","size":4661230720,"digest":"f66fc8dc39ea206e03ff6764fcc696b1b4dfb693f0b6ef751731dd4e6269046e","details":{"parent_model":"","format":"gguf","family":"llama","families":["llama"],"parameter_size":"8.0B","quantization_level":"Q4_0"}}]}'
+    tags = json.loads(json_response)
+    print(json.dumps(tags, indent=2))
+    conn = parse_ollama_tags(tags)
+    assert "phi3.5:latest" in conn.supported_models
+    assert "gemma2:2b" in conn.supported_models
+    assert "llama3.1:latest" in conn.supported_models
+    assert "scosman_net:latest" in conn.untested_models
+
+
+def test_parse_ollama_tags_only_untested_models():
+    json_response = '{"models":[{"name":"scosman_net","model":"scosman_net:latest"}]}'
+    tags = json.loads(json_response)
+    conn = parse_ollama_tags(tags)
+    assert conn.supported_models == []
+    assert conn.untested_models == ["scosman_net:latest"]
+
+
+def test_ollama_model_installed():
+    conn = OllamaConnection(
+        supported_models=["phi3.5:latest", "gemma2:2b", "llama3.1:latest"],
+        message="Connected",
+        untested_models=["scosman_net:latest"],
+    )
+    assert ollama_model_installed(conn, "phi3.5:latest")
+    assert ollama_model_installed(conn, "phi3.5")
+    assert ollama_model_installed(conn, "gemma2:2b")
+    assert ollama_model_installed(conn, "llama3.1:latest")
+    assert ollama_model_installed(conn, "llama3.1")
+    assert ollama_model_installed(conn, "scosman_net:latest")
+    assert ollama_model_installed(conn, "scosman_net")
+    assert not ollama_model_installed(conn, "unknown_model")
kiln_ai/adapters/test_prompt_adaptors.py
@@ -5,8 +5,10 @@ import pytest
 from langchain_core.language_models.fake_chat_models import FakeListChatModel
 
 import kiln_ai.datamodel as datamodel
-from kiln_ai.adapters.
-from kiln_ai.adapters.
+from kiln_ai.adapters.adapter_registry import adapter_for_task
+from kiln_ai.adapters.langchain_adapters import LangchainAdapter
+from kiln_ai.adapters.ml_model_list import built_in_models
+from kiln_ai.adapters.ollama_tools import ollama_online
 from kiln_ai.adapters.prompt_builders import (
     BasePromptBuilder,
     SimpleChainOfThoughtPromptBuilder,
@@ -106,7 +108,7 @@ async def test_amazon_bedrock(tmp_path):
 async def test_mock(tmp_path):
     task = build_test_task(tmp_path)
     mockChatModel = FakeListChatModel(responses=["mock response"])
-    adapter =
+    adapter = LangchainAdapter(task, custom_model=mockChatModel)
     run = await adapter.invoke("You are a mock, send me the response!")
     assert "mock response" in run.output.output
 
@@ -114,7 +116,7 @@ async def test_mock(tmp_path):
 async def test_mock_returning_run(tmp_path):
     task = build_test_task(tmp_path)
     mockChatModel = FakeListChatModel(responses=["mock response"])
-    adapter =
+    adapter = LangchainAdapter(task, custom_model=mockChatModel)
     run = await adapter.invoke("You are a mock, send me the response!")
     assert run.output.output == "mock response"
     assert run is not None
@@ -192,7 +194,7 @@ async def run_simple_task(
     provider: str,
     prompt_builder: BasePromptBuilder | None = None,
 ) -> datamodel.TaskRun:
-    adapter =
+    adapter = adapter_for_task(
         task, model_name=model_name, provider=provider, prompt_builder=prompt_builder
     )
 