kiln-ai 0.6.1__py3-none-any.whl → 0.7.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
- kiln_ai/adapters/__init__.py +2 -0
- kiln_ai/adapters/adapter_registry.py +19 -0
- kiln_ai/adapters/data_gen/test_data_gen_task.py +29 -21
- kiln_ai/adapters/fine_tune/__init__.py +14 -0
- kiln_ai/adapters/fine_tune/base_finetune.py +186 -0
- kiln_ai/adapters/fine_tune/dataset_formatter.py +187 -0
- kiln_ai/adapters/fine_tune/finetune_registry.py +11 -0
- kiln_ai/adapters/fine_tune/fireworks_finetune.py +308 -0
- kiln_ai/adapters/fine_tune/openai_finetune.py +205 -0
- kiln_ai/adapters/fine_tune/test_base_finetune.py +290 -0
- kiln_ai/adapters/fine_tune/test_dataset_formatter.py +342 -0
- kiln_ai/adapters/fine_tune/test_fireworks_tinetune.py +455 -0
- kiln_ai/adapters/fine_tune/test_openai_finetune.py +503 -0
- kiln_ai/adapters/langchain_adapters.py +103 -13
- kiln_ai/adapters/ml_model_list.py +239 -303
- kiln_ai/adapters/ollama_tools.py +115 -0
- kiln_ai/adapters/provider_tools.py +308 -0
- kiln_ai/adapters/repair/repair_task.py +4 -2
- kiln_ai/adapters/repair/test_repair_task.py +6 -11
- kiln_ai/adapters/test_langchain_adapter.py +229 -18
- kiln_ai/adapters/test_ollama_tools.py +42 -0
- kiln_ai/adapters/test_prompt_adaptors.py +7 -5
- kiln_ai/adapters/test_provider_tools.py +531 -0
- kiln_ai/adapters/test_structured_output.py +22 -43
- kiln_ai/datamodel/__init__.py +287 -24
- kiln_ai/datamodel/basemodel.py +122 -38
- kiln_ai/datamodel/model_cache.py +116 -0
- kiln_ai/datamodel/registry.py +31 -0
- kiln_ai/datamodel/test_basemodel.py +167 -4
- kiln_ai/datamodel/test_dataset_split.py +234 -0
- kiln_ai/datamodel/test_example_models.py +12 -0
- kiln_ai/datamodel/test_model_cache.py +244 -0
- kiln_ai/datamodel/test_models.py +215 -1
- kiln_ai/datamodel/test_registry.py +96 -0
- kiln_ai/utils/config.py +14 -1
- kiln_ai/utils/name_generator.py +125 -0
- kiln_ai/utils/test_name_geneator.py +47 -0
- kiln_ai-0.7.1.dist-info/METADATA +237 -0
- kiln_ai-0.7.1.dist-info/RECORD +58 -0
- {kiln_ai-0.6.1.dist-info → kiln_ai-0.7.1.dist-info}/WHEEL +1 -1
- kiln_ai/adapters/test_ml_model_list.py +0 -181
- kiln_ai-0.6.1.dist-info/METADATA +0 -88
- kiln_ai-0.6.1.dist-info/RECORD +0 -37
- {kiln_ai-0.6.1.dist-info → kiln_ai-0.7.1.dist-info}/licenses/LICENSE.txt +0 -0
kiln_ai/adapters/ollama_tools.py
@@ -0,0 +1,115 @@
+import os
+from typing import Any, List
+
+import httpx
+import requests
+from pydantic import BaseModel, Field
+
+from kiln_ai.adapters.ml_model_list import ModelProviderName, built_in_models
+from kiln_ai.utils.config import Config
+
+
+def ollama_base_url() -> str:
+    """
+    Gets the base URL for Ollama API connections.
+
+    Returns:
+        The base URL to use for Ollama API calls, using the configured value
+        if set or falling back to the localhost default
+    """
+    config_base_url = Config.shared().ollama_base_url
+    if config_base_url:
+        return config_base_url
+    return "http://localhost:11434"
+
+
+async def ollama_online() -> bool:
+    """
+    Checks if the Ollama service is available and responding.
+
+    Returns:
+        True if Ollama is available and responding, False otherwise
+    """
+    try:
+        httpx.get(ollama_base_url() + "/api/tags")
+    except httpx.RequestError:
+        return False
+    return True
+
+
+class OllamaConnection(BaseModel):
+    message: str
+    supported_models: List[str]
+    untested_models: List[str] = Field(default_factory=list)
+
+    def all_models(self) -> List[str]:
+        return self.supported_models + self.untested_models
+
+
+# Parse the Ollama /api/tags response
+def parse_ollama_tags(tags: Any) -> OllamaConnection | None:
+    # Build a list of models we support for Ollama from the built-in model list
+    supported_ollama_models = [
+        provider.provider_options["model"]
+        for model in built_in_models
+        for provider in model.providers
+        if provider.name == ModelProviderName.ollama
+    ]
+    # Append model_aliases to supported_ollama_models
+    supported_ollama_models.extend(
+        [
+            alias
+            for model in built_in_models
+            for provider in model.providers
+            for alias in provider.provider_options.get("model_aliases", [])
+        ]
+    )
+
+    if "models" in tags:
+        models = tags["models"]
+        if isinstance(models, list):
+            model_names = [model["model"] for model in models]
+            available_supported_models = []
+            untested_models = []
+            supported_models_latest_aliases = [
+                f"{m}:latest" for m in supported_ollama_models
+            ]
+            for model in model_names:
+                if (
+                    model in supported_ollama_models
+                    or model in supported_models_latest_aliases
+                ):
+                    available_supported_models.append(model)
+                else:
+                    untested_models.append(model)
+
+            if available_supported_models or untested_models:
+                return OllamaConnection(
+                    message="Ollama connected",
+                    supported_models=available_supported_models,
+                    untested_models=untested_models,
+                )
+
+    return OllamaConnection(
+        message="Ollama is running, but no supported models are installed. Install one or more supported models, like 'ollama pull phi3.5'.",
+        supported_models=[],
+        untested_models=[],
+    )
+
+
+async def get_ollama_connection() -> OllamaConnection | None:
+    """
+    Gets the connection status for Ollama.
+    """
+    try:
+        tags = requests.get(ollama_base_url() + "/api/tags", timeout=5).json()
+
+    except Exception:
+        return None
+
+    return parse_ollama_tags(tags)
+
+
+def ollama_model_installed(conn: OllamaConnection, model_name: str) -> bool:
+    all_models = conn.all_models()
+    return model_name in all_models or f"{model_name}:latest" in all_models
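For orientation (not part of the diff): the new module composes into a two-step flow, fetch the connection, then check for a model. A minimal usage sketch, assuming Ollama is running locally and using "phi3.5" purely as an example model name:

import asyncio

from kiln_ai.adapters.ollama_tools import (
    get_ollama_connection,
    ollama_model_installed,
)


async def main() -> None:
    # get_ollama_connection returns None when the /api/tags request fails
    # (e.g. Ollama is not running)
    conn = await get_ollama_connection()
    if conn is None:
        print("Ollama is not reachable")
        return
    print(conn.message)
    # ollama_model_installed matches both "phi3.5" and its "phi3.5:latest" alias
    print("phi3.5 installed:", ollama_model_installed(conn, "phi3.5"))


asyncio.run(main())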
kiln_ai/adapters/provider_tools.py
@@ -0,0 +1,308 @@
+from dataclasses import dataclass
+from typing import Dict, List, NoReturn
+
+from kiln_ai.adapters.ml_model_list import (
+    KilnModel,
+    KilnModelProvider,
+    ModelName,
+    ModelProviderName,
+    built_in_models,
+)
+from kiln_ai.adapters.ollama_tools import (
+    get_ollama_connection,
+)
+from kiln_ai.datamodel import Finetune, Task
+from kiln_ai.datamodel.registry import project_from_id
+
+from ..utils.config import Config
+
+
+async def provider_enabled(provider_name: ModelProviderName) -> bool:
+    if provider_name == ModelProviderName.ollama:
+        try:
+            conn = await get_ollama_connection()
+            return conn is not None and (
+                len(conn.supported_models) > 0 or len(conn.untested_models) > 0
+            )
+        except Exception:
+            return False
+
+    provider_warning = provider_warnings.get(provider_name)
+    if provider_warning is None:
+        return False
+    for required_key in provider_warning.required_config_keys:
+        if get_config_value(required_key) is None:
+            return False
+    return True
+
+
+def get_config_value(key: str):
+    try:
+        return Config.shared().__getattr__(key)
+    except AttributeError:
+        return None
+
+
+def check_provider_warnings(provider_name: ModelProviderName):
+    """
+    Validates that required configuration is present for a given provider.
+
+    Args:
+        provider_name: The provider to check
+
+    Raises:
+        ValueError: If required configuration keys are missing
+    """
+    warning_check = provider_warnings.get(provider_name)
+    if warning_check is None:
+        return
+    for key in warning_check.required_config_keys:
+        if get_config_value(key) is None:
+            raise ValueError(warning_check.message)
+
+
+async def builtin_model_from(
+    name: str, provider_name: str | None = None
+) -> KilnModelProvider | None:
+    """
+    Gets a model provider from the built-in list of models.
+
+    Args:
+        name: The name of the model to get
+        provider_name: Optional specific provider to use (defaults to first available)
+
+    Returns:
+        The provider for the model, or None if the model or provider is not found
+
+    Raises:
+        ValueError: If the model has no providers, or if the provider is misconfigured
+    """
+    if name not in ModelName.__members__:
+        return None
+
+    # Select the model from built_in_models using the name
+    model = next(filter(lambda m: m.name == name, built_in_models))
+    if model is None:
+        raise ValueError(f"Model {name} not found")
+
+    # If a provider is provided, select the provider from the model's provider_config
+    provider: KilnModelProvider | None = None
+    if model.providers is None or len(model.providers) == 0:
+        raise ValueError(f"Model {name} has no providers")
+    elif provider_name is None:
+        provider = model.providers[0]
+    else:
+        provider = next(
+            filter(lambda p: p.name == provider_name, model.providers), None
+        )
+        if provider is None:
+            return None
+
+    check_provider_warnings(provider.name)
+    return provider
+
+
+async def kiln_model_provider_from(
+    name: str, provider_name: str | None = None
+) -> KilnModelProvider:
+    if provider_name == ModelProviderName.kiln_fine_tune:
+        return finetune_provider_model(name)
+
+    built_in_model = await builtin_model_from(name, provider_name)
+    if built_in_model:
+        return built_in_model
+
+    # For custom registry, get the provider name and model name from the model id
+    if provider_name == ModelProviderName.kiln_custom_registry:
+        provider_name = name.split("::", 1)[0]
+        name = name.split("::", 1)[1]
+
+    # Custom/untested model. Set untested, and build a ModelProvider at runtime
+    if provider_name is None:
+        raise ValueError("Provider name is required for custom models")
+    if provider_name not in ModelProviderName.__members__:
+        raise ValueError(f"Invalid provider name: {provider_name}")
+    provider = ModelProviderName(provider_name)
+    check_provider_warnings(provider)
+    return KilnModelProvider(
+        name=provider,
+        supports_structured_output=False,
+        supports_data_gen=False,
+        untested_model=True,
+        provider_options=provider_options_for_custom_model(name, provider_name),
+    )
+
+
+finetune_cache: dict[str, KilnModelProvider] = {}
+
+
+def finetune_provider_model(
+    model_id: str,
+) -> KilnModelProvider:
+    if model_id in finetune_cache:
+        return finetune_cache[model_id]
+
+    try:
+        project_id, task_id, fine_tune_id = model_id.split("::")
+    except Exception:
+        raise ValueError(f"Invalid fine tune ID: {model_id}")
+    project = project_from_id(project_id)
+    if project is None:
+        raise ValueError(f"Project {project_id} not found")
+    task = Task.from_id_and_parent_path(task_id, project.path)
+    if task is None:
+        raise ValueError(f"Task {task_id} not found")
+    fine_tune = Finetune.from_id_and_parent_path(fine_tune_id, task.path)
+    if fine_tune is None:
+        raise ValueError(f"Fine tune {fine_tune_id} not found")
+    if fine_tune.fine_tune_model_id is None:
+        raise ValueError(
+            f"Fine tune {fine_tune_id} not completed. Refresh its status in the fine-tune tab."
+        )
+
+    provider = ModelProviderName[fine_tune.provider]
+    model_provider = KilnModelProvider(
+        name=provider,
+        provider_options={
+            "model": fine_tune.fine_tune_model_id,
+        },
+    )
+
+    # TODO: Don't love this abstraction/logic.
+    if fine_tune.provider == ModelProviderName.fireworks_ai:
+        # Fireworks finetunes are trained with json, not tool calling (which is LC default format)
+        model_provider.adapter_options = {
+            "langchain": {
+                "with_structured_output_options": {
+                    "method": "json_mode",
+                }
+            }
+        }
+
+    finetune_cache[model_id] = model_provider
+    return model_provider
+
+
+def get_model_and_provider(
+    model_name: str, provider_name: str
+) -> tuple[KilnModel | None, KilnModelProvider | None]:
+    model = next(filter(lambda m: m.name == model_name, built_in_models), None)
+    if model is None:
+        return None, None
+    provider = next(filter(lambda p: p.name == provider_name, model.providers), None)
+    # all or nothing
+    if provider is None or model is None:
+        return None, None
+    return model, provider
+
+
+def provider_name_from_id(id: str) -> str:
+    """
+    Converts a provider ID to its human-readable name.
+
+    Args:
+        id: The provider identifier string
+
+    Returns:
+        The human-readable name of the provider
+
+    Raises:
+        ValueError: If the provider ID is invalid or unhandled
+    """
+    if id in ModelProviderName.__members__:
+        enum_id = ModelProviderName(id)
+        match enum_id:
+            case ModelProviderName.amazon_bedrock:
+                return "Amazon Bedrock"
+            case ModelProviderName.openrouter:
+                return "OpenRouter"
+            case ModelProviderName.groq:
+                return "Groq"
+            case ModelProviderName.ollama:
+                return "Ollama"
+            case ModelProviderName.openai:
+                return "OpenAI"
+            case ModelProviderName.kiln_fine_tune:
+                return "Fine Tuned Models"
+            case ModelProviderName.fireworks_ai:
+                return "Fireworks AI"
+            case ModelProviderName.kiln_custom_registry:
+                return "Custom Models"
+            case _:
+                # triggers pyright warning if I miss a case
+                raise_exhaustive_error(enum_id)
+
+    return "Unknown provider: " + id
+
+
+def provider_options_for_custom_model(
+    model_name: str, provider_name: str
+) -> Dict[str, str]:
+    """
+    Generate model provider options for a custom model. Each provider has its own format/options.
+    """
+
+    if provider_name not in ModelProviderName.__members__:
+        raise ValueError(f"Invalid provider name: {provider_name}")
+
+    enum_id = ModelProviderName(provider_name)
+    match enum_id:
+        case ModelProviderName.amazon_bedrock:
+            # us-west-2 is the only region consistently supported by Bedrock
+            return {"model": model_name, "region_name": "us-west-2"}
+        case (
+            ModelProviderName.openai
+            | ModelProviderName.ollama
+            | ModelProviderName.fireworks_ai
+            | ModelProviderName.openrouter
+            | ModelProviderName.groq
+        ):
+            return {"model": model_name}
+        case ModelProviderName.kiln_custom_registry:
+            raise ValueError(
+                "Custom models from registry should be parsed into provider/model before calling this."
+            )
+        case ModelProviderName.kiln_fine_tune:
+            raise ValueError(
+                "Fine tuned models should populate provider options via another path"
+            )
+        case _:
+            # triggers pyright warning if I miss a case
+            raise_exhaustive_error(enum_id)
+
+    # Won't reach this, type checking will catch missed values
+    return {"model": model_name}
+
+
+def raise_exhaustive_error(value: NoReturn) -> NoReturn:
+    raise ValueError(f"Unhandled enum value: {value}")
+
+
+@dataclass
+class ModelProviderWarning:
+    required_config_keys: List[str]
+    message: str
+
+
+provider_warnings: Dict[ModelProviderName, ModelProviderWarning] = {
+    ModelProviderName.amazon_bedrock: ModelProviderWarning(
+        required_config_keys=["bedrock_access_key", "bedrock_secret_key"],
+        message="Attempted to use Amazon Bedrock without an access key and secret set. \nGet your keys from https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/overview",
+    ),
+    ModelProviderName.openrouter: ModelProviderWarning(
+        required_config_keys=["open_router_api_key"],
+        message="Attempted to use OpenRouter without an API key set. \nGet your API key from https://openrouter.ai/settings/keys",
+    ),
+    ModelProviderName.groq: ModelProviderWarning(
+        required_config_keys=["groq_api_key"],
+        message="Attempted to use Groq without an API key set. \nGet your API key from https://console.groq.com/keys",
+    ),
+    ModelProviderName.openai: ModelProviderWarning(
+        required_config_keys=["open_ai_api_key"],
+        message="Attempted to use OpenAI without an API key set. \nGet your API key from https://platform.openai.com/account/api-keys",
+    ),
+    ModelProviderName.fireworks_ai: ModelProviderWarning(
+        required_config_keys=["fireworks_api_key", "fireworks_account_id"],
+        message="Attempted to use Fireworks without an API key and account ID set. \nGet your API key from https://fireworks.ai/account/api-keys and your account ID from https://fireworks.ai/account/profile",
+    ),
+}
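For orientation (not part of the diff): kiln_model_provider_from dispatches on two "::"-delimited ID formats, "provider::model" for the custom registry and "project::task::finetune" for fine-tunes. A sketch of the dispatch under assumed inputs; the custom model name and the fine-tune IDs below are illustrative, and the calls can raise ValueError if the provider's keys aren't configured or the fine-tune data isn't on disk:

from kiln_ai.adapters.ml_model_list import ModelProviderName
from kiln_ai.adapters.provider_tools import kiln_model_provider_from


async def demo() -> None:
    # Built-in model: resolved from built_in_models (first provider by default);
    # raises if that provider's required config keys are missing
    builtin = await kiln_model_provider_from("llama_3_1_8b")

    # Custom registry ID: "provider::model"; the model name is hypothetical.
    # The result is flagged untested_model=True with structured output disabled.
    custom = await kiln_model_provider_from(
        "openai::my-custom-model", ModelProviderName.kiln_custom_registry
    )

    # Fine-tune ID: "project::task::finetune"; these IDs are illustrative and
    # the lookup raises ValueError unless they exist on disk.
    tuned = await kiln_model_provider_from(
        "proj_id::task_id::ft_id", ModelProviderName.kiln_fine_tune
    )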
kiln_ai/adapters/repair/repair_task.py
@@ -43,8 +43,10 @@ feedback describing what should be improved. Your job is to understand the evalu
     @classmethod
     def _original_prompt(cls, run: TaskRun, task: Task) -> str:
         prompt_builder_class: Type[BasePromptBuilder] | None = None
-        prompt_builder_name = run.output.source.properties.get(
-            "prompt_builder_name", None
+        prompt_builder_name = (
+            run.output.source.properties.get("prompt_builder_name", None)
+            if run.output.source
+            else None
         )
         if prompt_builder_name is not None and isinstance(prompt_builder_name, str):
             prompt_builder_class = prompt_builder_registry.get(
kiln_ai/adapters/repair/test_repair_task.py
@@ -5,10 +5,9 @@ from unittest.mock import AsyncMock, patch
 import pytest
 from pydantic import ValidationError
 
+from kiln_ai.adapters.adapter_registry import adapter_for_task
 from kiln_ai.adapters.base_adapter import RunOutput
-from kiln_ai.adapters.langchain_adapters import (
-    LangChainPromptAdapter,
-)
+from kiln_ai.adapters.langchain_adapters import LangchainAdapter
 from kiln_ai.adapters.repair.repair_task import (
     RepairTaskInput,
     RepairTaskRun,
@@ -60,7 +59,7 @@ json_joke_schema = """{
 
 @pytest.fixture
 def sample_task(tmp_path):
-    task_path = tmp_path / "task.
+    task_path = tmp_path / "task.kiln"
     task = Task(
         name="Joke Generator",
         path=task_path,
|
|
|
190
189
|
repair_task_input = RepairTaskRun.build_repair_task_input(**sample_repair_data)
|
|
191
190
|
assert isinstance(repair_task_input, RepairTaskInput)
|
|
192
191
|
|
|
193
|
-
adapter =
|
|
194
|
-
repair_task, model_name="llama_3_1_8b", provider="groq"
|
|
195
|
-
)
|
|
192
|
+
adapter = adapter_for_task(repair_task, model_name="llama_3_1_8b", provider="groq")
|
|
196
193
|
|
|
197
194
|
run = await adapter.invoke(repair_task_input.model_dump())
|
|
198
195
|
assert run is not None
|
|
@@ -220,14 +217,12 @@ async def test_mocked_repair_task_run(sample_task, sample_task_run, sample_repai
         "rating": 8,
     }
 
-    with patch.object(
-        LangChainPromptAdapter, "_run", new_callable=AsyncMock
-    ) as mock_run:
+    with patch.object(LangchainAdapter, "_run", new_callable=AsyncMock) as mock_run:
         mock_run.return_value = RunOutput(
             output=mocked_output, intermediate_outputs=None
         )
 
-        adapter = LangChainPromptAdapter(
+        adapter = adapter_for_task(
             repair_task, model_name="llama_3_1_8b", provider="groq"
         )
 
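For orientation (not part of the diff): the test changes above track this release's new adapter_registry entry point (kiln_ai/adapters/adapter_registry.py, +19 lines), which replaces direct LangChainPromptAdapter construction. A hedged sketch of the new call path, with the task object and input elided and the signature inferred from the tests above:

from kiln_ai.adapters.adapter_registry import adapter_for_task


async def run_task(task, input_data: dict):
    # adapter_for_task selects the adapter for the model/provider pair
    # (a LangchainAdapter under the hood, as the patched tests suggest)
    adapter = adapter_for_task(task, model_name="llama_3_1_8b", provider="groq")
    return await adapter.invoke(input_data)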