kiln-ai 0.12.0__py3-none-any.whl → 0.13.2__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to a supported registry, and is provided for informational purposes only.
Potentially problematic release.
- kiln_ai/adapters/__init__.py +4 -0
- kiln_ai/adapters/adapter_registry.py +157 -28
- kiln_ai/adapters/eval/__init__.py +28 -0
- kiln_ai/adapters/eval/eval_runner.py +4 -1
- kiln_ai/adapters/eval/g_eval.py +19 -3
- kiln_ai/adapters/eval/test_base_eval.py +1 -0
- kiln_ai/adapters/eval/test_eval_runner.py +1 -0
- kiln_ai/adapters/eval/test_g_eval.py +13 -7
- kiln_ai/adapters/fine_tune/base_finetune.py +16 -2
- kiln_ai/adapters/fine_tune/finetune_registry.py +2 -0
- kiln_ai/adapters/fine_tune/fireworks_finetune.py +8 -1
- kiln_ai/adapters/fine_tune/test_fireworks_finetune.py +19 -0
- kiln_ai/adapters/fine_tune/test_together_finetune.py +533 -0
- kiln_ai/adapters/fine_tune/together_finetune.py +327 -0
- kiln_ai/adapters/ml_model_list.py +638 -155
- kiln_ai/adapters/model_adapters/__init__.py +2 -4
- kiln_ai/adapters/model_adapters/base_adapter.py +14 -11
- kiln_ai/adapters/model_adapters/litellm_adapter.py +391 -0
- kiln_ai/adapters/model_adapters/litellm_config.py +13 -0
- kiln_ai/adapters/model_adapters/test_litellm_adapter.py +407 -0
- kiln_ai/adapters/model_adapters/test_structured_output.py +23 -5
- kiln_ai/adapters/ollama_tools.py +3 -2
- kiln_ai/adapters/parsers/r1_parser.py +19 -14
- kiln_ai/adapters/parsers/test_r1_parser.py +17 -5
- kiln_ai/adapters/provider_tools.py +52 -60
- kiln_ai/adapters/repair/test_repair_task.py +3 -3
- kiln_ai/adapters/run_output.py +1 -1
- kiln_ai/adapters/test_adapter_registry.py +17 -20
- kiln_ai/adapters/test_generate_docs.py +2 -2
- kiln_ai/adapters/test_prompt_adaptors.py +30 -19
- kiln_ai/adapters/test_provider_tools.py +27 -82
- kiln_ai/datamodel/basemodel.py +2 -0
- kiln_ai/datamodel/datamodel_enums.py +2 -0
- kiln_ai/datamodel/json_schema.py +1 -1
- kiln_ai/datamodel/task_output.py +13 -6
- kiln_ai/datamodel/test_basemodel.py +9 -0
- kiln_ai/datamodel/test_datasource.py +19 -0
- kiln_ai/utils/config.py +46 -0
- kiln_ai/utils/dataset_import.py +232 -0
- kiln_ai/utils/test_dataset_import.py +596 -0
- {kiln_ai-0.12.0.dist-info → kiln_ai-0.13.2.dist-info}/METADATA +51 -7
- {kiln_ai-0.12.0.dist-info → kiln_ai-0.13.2.dist-info}/RECORD +44 -41
- kiln_ai/adapters/model_adapters/langchain_adapters.py +0 -309
- kiln_ai/adapters/model_adapters/openai_compatible_config.py +0 -10
- kiln_ai/adapters/model_adapters/openai_model_adapter.py +0 -289
- kiln_ai/adapters/model_adapters/test_langchain_adapter.py +0 -343
- kiln_ai/adapters/model_adapters/test_openai_model_adapter.py +0 -216
- {kiln_ai-0.12.0.dist-info → kiln_ai-0.13.2.dist-info}/WHEEL +0 -0
- {kiln_ai-0.12.0.dist-info → kiln_ai-0.13.2.dist-info}/licenses/LICENSE.txt +0 -0
kiln_ai/adapters/__init__.py
CHANGED
@@ -12,10 +12,13 @@ The prompt_builders submodule contains classes that build prompts for use with t
 The repair submodule contains an adapter for the repair task.
 
 The parser submodule contains parsers for the output of the AI models.
+
+The eval submodule contains the code for evaluating the performance of a model.
 """
 
 from . import (
     data_gen,
+    eval,
     fine_tune,
     ml_model_list,
     model_adapters,
@@ -30,4 +33,5 @@ __all__ = [
     "ml_model_list",
     "prompt_builders",
     "repair",
+    "eval",
 ]
kiln_ai/adapters/adapter_registry.py
CHANGED
@@ -3,12 +3,11 @@ from os import getenv
 from kiln_ai import datamodel
 from kiln_ai.adapters.ml_model_list import ModelProviderName
 from kiln_ai.adapters.model_adapters.base_adapter import AdapterConfig, BaseAdapter
-from kiln_ai.adapters.model_adapters.
-
-
-    OpenAICompatibleConfig,
+from kiln_ai.adapters.model_adapters.litellm_adapter import (
+    LiteLlmAdapter,
+    LiteLlmConfig,
 )
-from kiln_ai.adapters.provider_tools import core_provider,
+from kiln_ai.adapters.provider_tools import core_provider, lite_llm_config
 from kiln_ai.datamodel import PromptId
 from kiln_ai.utils.config import Config
 from kiln_ai.utils.exhaustive_error import raise_exhaustive_enum_error
@@ -26,50 +25,189 @@ def adapter_for_task(
 
     match core_provider_name:
         case ModelProviderName.openrouter:
-            return
+            return LiteLlmAdapter(
                 kiln_task=kiln_task,
-                config=
+                config=LiteLlmConfig(
+                    model_name=model_name,
                     base_url=getenv("OPENROUTER_BASE_URL")
                     or "https://openrouter.ai/api/v1",
-                    api_key=Config.shared().open_router_api_key,
-                    model_name=model_name,
                     provider_name=provider,
                     default_headers={
                         "HTTP-Referer": "https://getkiln.ai/openrouter",
                         "X-Title": "KilnAI",
                     },
+                    additional_body_options={
+                        "api_key": Config.shared().open_router_api_key,
+                    },
                 ),
                 prompt_id=prompt_id,
                 base_adapter_config=base_adapter_config,
             )
         case ModelProviderName.openai:
-            return
+            return LiteLlmAdapter(
                 kiln_task=kiln_task,
-                config=
-                    api_key=Config.shared().open_ai_api_key,
+                config=LiteLlmConfig(
                     model_name=model_name,
                     provider_name=provider,
+                    additional_body_options={
+                        "api_key": Config.shared().open_ai_api_key,
+                    },
                 ),
                 prompt_id=prompt_id,
                 base_adapter_config=base_adapter_config,
             )
         case ModelProviderName.openai_compatible:
-            config =
-            return
+            config = lite_llm_config(model_name)
+            return LiteLlmAdapter(
                 kiln_task=kiln_task,
                 config=config,
                 prompt_id=prompt_id,
                 base_adapter_config=base_adapter_config,
            )
-        # Use LangchainAdapter for the rest
         case ModelProviderName.groq:
-
+            return LiteLlmAdapter(
+                kiln_task=kiln_task,
+                prompt_id=prompt_id,
+                base_adapter_config=base_adapter_config,
+                config=LiteLlmConfig(
+                    model_name=model_name,
+                    provider_name=provider,
+                    additional_body_options={
+                        "api_key": Config.shared().groq_api_key,
+                    },
+                ),
+            )
         case ModelProviderName.amazon_bedrock:
-
+            return LiteLlmAdapter(
+                kiln_task=kiln_task,
+                prompt_id=prompt_id,
+                base_adapter_config=base_adapter_config,
+                config=LiteLlmConfig(
+                    model_name=model_name,
+                    provider_name=provider,
+                    additional_body_options={
+                        "aws_access_key_id": Config.shared().bedrock_access_key,
+                        "aws_secret_access_key": Config.shared().bedrock_secret_key,
+                        # The only region that's widely supported for bedrock
+                        "aws_region_name": "us-west-2",
+                    },
+                ),
+            )
         case ModelProviderName.ollama:
-
+            ollama_base_url = (
+                Config.shared().ollama_base_url or "http://localhost:11434"
+            )
+            return LiteLlmAdapter(
+                kiln_task=kiln_task,
+                prompt_id=prompt_id,
+                base_adapter_config=base_adapter_config,
+                config=LiteLlmConfig(
+                    model_name=model_name,
+                    provider_name=provider,
+                    # Set the Ollama base URL for 2 reasons:
+                    # 1. To use the correct base URL
+                    # 2. We use Ollama's OpenAI compatible API (/v1), and don't just let litellm use the Ollama API. We use more advanced features like json_schema.
+                    base_url=ollama_base_url + "/v1",
+                    additional_body_options={
+                        # LiteLLM errors without an api_key, even though Ollama doesn't support one.
+                        "api_key": "NA",
+                    },
+                ),
+            )
         case ModelProviderName.fireworks_ai:
-
+            return LiteLlmAdapter(
+                kiln_task=kiln_task,
+                prompt_id=prompt_id,
+                base_adapter_config=base_adapter_config,
+                config=LiteLlmConfig(
+                    model_name=model_name,
+                    provider_name=provider,
+                    additional_body_options={
+                        "api_key": Config.shared().fireworks_api_key,
+                    },
+                ),
+            )
+        case ModelProviderName.anthropic:
+            return LiteLlmAdapter(
+                kiln_task=kiln_task,
+                prompt_id=prompt_id,
+                base_adapter_config=base_adapter_config,
+                config=LiteLlmConfig(
+                    model_name=model_name,
+                    provider_name=provider,
+                    additional_body_options={
+                        "api_key": Config.shared().anthropic_api_key,
+                    },
+                ),
+            )
+        case ModelProviderName.gemini_api:
+            return LiteLlmAdapter(
+                kiln_task=kiln_task,
+                prompt_id=prompt_id,
+                base_adapter_config=base_adapter_config,
+                config=LiteLlmConfig(
+                    model_name=model_name,
+                    provider_name=provider,
+                    additional_body_options={
+                        "api_key": Config.shared().gemini_api_key,
+                    },
+                ),
+            )
+        case ModelProviderName.vertex:
+            return LiteLlmAdapter(
+                kiln_task=kiln_task,
+                prompt_id=prompt_id,
+                base_adapter_config=base_adapter_config,
+                config=LiteLlmConfig(
+                    model_name=model_name,
+                    provider_name=provider,
+                    additional_body_options={
+                        "vertex_project": Config.shared().vertex_project_id,
+                        "vertex_location": Config.shared().vertex_location,
+                    },
+                ),
+            )
+        case ModelProviderName.together_ai:
+            return LiteLlmAdapter(
+                kiln_task=kiln_task,
+                prompt_id=prompt_id,
+                base_adapter_config=base_adapter_config,
+                config=LiteLlmConfig(
+                    model_name=model_name,
+                    provider_name=provider,
+                    additional_body_options={
+                        "api_key": Config.shared().together_api_key,
+                    },
+                ),
+            )
+        case ModelProviderName.azure_openai:
+            return LiteLlmAdapter(
+                kiln_task=kiln_task,
+                prompt_id=prompt_id,
+                base_adapter_config=base_adapter_config,
+                config=LiteLlmConfig(
+                    base_url=Config.shared().azure_openai_endpoint,
+                    model_name=model_name,
+                    provider_name=provider,
+                    additional_body_options={
+                        "api_key": Config.shared().azure_openai_api_key,
+                        "api_version": "2025-02-01-preview",
+                    },
+                ),
+            )
+        case ModelProviderName.huggingface:
+            return LiteLlmAdapter(
+                kiln_task=kiln_task,
+                prompt_id=prompt_id,
+                base_adapter_config=base_adapter_config,
+                config=LiteLlmConfig(
+                    model_name=model_name,
+                    provider_name=provider,
+                    additional_body_options={
+                        "api_key": Config.shared().huggingface_api_key,
+                    },
+                ),
+            )
         # These are virtual providers that should have mapped to an actual provider in core_provider
         case ModelProviderName.kiln_fine_tune:
             raise ValueError(
@@ -81,12 +219,3 @@ def adapter_for_task(
             )
         case _:
             raise_exhaustive_enum_error(core_provider_name)
-
-    # We use langchain for all others right now, but moving off it as we touch anything.
-    return LangchainAdapter(
-        kiln_task,
-        model_name=model_name,
-        provider=provider,
-        prompt_id=prompt_id,
-        base_adapter_config=base_adapter_config,
-    )
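With this change every core provider resolves to a `LiteLlmAdapter` built from a per-provider `LiteLlmConfig`; the LangChain and OpenAI-specific adapters are removed. Below is a minimal usage sketch of the registry, assuming the public signature mirrors the parameters used in the body (`kiln_task`, `model_name`, `provider`, `prompt_id`, `base_adapter_config`); the task fields and model id are illustrative placeholders, not values from this diff.

```python
# Hypothetical usage sketch, not code from the package diff.
from kiln_ai import datamodel
from kiln_ai.adapters.adapter_registry import adapter_for_task
from kiln_ai.adapters.ml_model_list import ModelProviderName

# Illustrative task; real projects load tasks from the Kiln datamodel on disk.
task = datamodel.Task(name="summarize", instruction="Summarize the input text.")

# The match statement above resolves the provider and returns a LiteLlmAdapter
# whose LiteLlmConfig carries the provider's credentials from Config.shared().
adapter = adapter_for_task(
    kiln_task=task,
    model_name="llama_3_1_8b",  # placeholder model id
    provider=ModelProviderName.groq,
)
```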
kiln_ai/adapters/eval/__init__.py
ADDED
@@ -0,0 +1,28 @@
+"""
+# Evals
+
+This module contains the code for evaluating the performance of a model.
+
+The submodules contain:
+
+- BaseEval: each eval technique implements this interface.
+- G-Eval: an eval implementation that implements G-Eval and LLM as Judge.
+- EvalRunner: a class that runs a full evaluation (many smaller eval jobs). Includes async parallel processing, and the ability to restart where it left off.
+- EvalRegistry: a registry for all eval implementations.
+
+The datamodel for Evals is in the `kiln_ai.datamodel.eval` module.
+"""
+
+from . import (
+    base_eval,
+    eval_runner,
+    g_eval,
+    registry,
+)
+
+__all__ = [
+    "base_eval",
+    "eval_runner",
+    "g_eval",
+    "registry",
+]
kiln_ai/adapters/eval/eval_runner.py
CHANGED
@@ -139,7 +139,10 @@ class EvalRunner:
             for run_config in self.run_configs or []:
                 already_run[eval_config.id][run_config.id] = set()
                 for run in eval_config.runs(readonly=True):
-                    if
+                    if (
+                        run.task_run_config_id is not None
+                        and run.task_run_config_id in already_run[eval_config.id]
+                    ):
                         already_run[eval_config.id][run.task_run_config_id].add(
                             run.dataset_id
                         )
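The `EvalRunner` fix guards the `already_run` bookkeeping: a persisted run is only recorded when its `task_run_config_id` is set and belongs to one of the run configs in the current pass. A plain-Python sketch of the structure being protected (ids are illustrative):

```python
# Sketch of the dedup map EvalRunner builds: eval config id -> task run config id
# -> set of dataset item ids that were already evaluated.
already_run: dict[str, dict[str, set[str]]] = {"eval_cfg_1": {"run_cfg_1": set()}}

# A previously saved eval run. Its task_run_config_id may be None, or may point
# at a run config that is not part of the current eval pass.
run = {"task_run_config_id": None, "dataset_id": "item_42"}

rc_id = run["task_run_config_id"]
# Without the guard, indexing already_run["eval_cfg_1"][rc_id] could KeyError or
# key on None; with it, runs that don't map to a known run config are skipped.
if rc_id is not None and rc_id in already_run["eval_cfg_1"]:
    already_run["eval_cfg_1"][rc_id].add(run["dataset_id"])
```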
kiln_ai/adapters/eval/g_eval.py
CHANGED
@@ -1,6 +1,8 @@
 import math
 from typing import Dict, List, Tuple
 
+from litellm.types.utils import ChatCompletionTokenLogprob
+
 from kiln_ai.adapters.adapter_registry import adapter_for_task
 from kiln_ai.adapters.eval.base_eval import BaseEval
 from kiln_ai.adapters.model_adapters.base_adapter import AdapterConfig, RunOutput
@@ -8,7 +10,6 @@ from kiln_ai.adapters.prompt_builders import PromptGenerators
 from kiln_ai.datamodel import Project, Task, TaskRun
 from kiln_ai.datamodel.eval import EvalConfig, EvalConfigType, EvalScores
 from kiln_ai.datamodel.task import RunConfig
-from openai.types.chat import ChatCompletionTokenLogprob
 
 # all the tokens we score for, and their float scores.
 TOKEN_TO_SCORE_MAP: Dict[str, float] = {
@@ -296,9 +297,12 @@ The model produced the following output for the task:
 
         total_score = 0.0
         total_probability = 0.0
+        top_logprobs_contains_primary_token = False
 
-        # Process all valid scoring tokens
+        # Process all valid scoring tokens from alternatives
         for top_logprob in token_logprob.top_logprobs:
+            if top_logprob.token == token_logprob.token:
+                top_logprobs_contains_primary_token = True
             token_score = self.score_from_token_string(top_logprob.token)
             if token_score is not None:
                 # Convert logprob to probability
@@ -306,9 +310,21 @@ The model produced the following output for the task:
                 total_score += token_score * probability
                 total_probability += probability
 
+        # Weird OpenAI 4o bug - sometimes the primary token is included in the top logprobs, sometimes not.
+        # Add the primary token back in if excluded
+        if not top_logprobs_contains_primary_token:
+            if token_logprob.logprob == -9999.0:
+                # Another "bug" - sometimes the logprob is -9999.0. This seems to happen when the rest of the logprobs are tiny probability.
+                total_score += primary_token_score * 1.0
+                total_probability += 1.0
+            else:
+                probability = math.exp(token_logprob.logprob)
+                total_score += primary_token_score * probability
+                total_probability += probability
+
         if total_probability <= 0.0:
             raise RuntimeError(
-                f"No valid scoring tokens found for {token_logprob.token}. This should never happen. Please file a bug if you see this."
+                f"No valid scoring tokens found for {token_logprob.token}. This should never happen as the token has a valid score (so it must be excluded from top logprobs). Please file a bug if you see this."
             )
 
         # Normalize by total probability of valid tokens (LLM may have wanted to generate other non-rating tokens, these shouldn't lower score of rating tokens)
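`rating_token_to_score` turns the judge's logprobs into a weighted rating: each candidate rating token contributes `score * exp(logprob)`, and the sum is normalized by the probability mass of valid rating tokens. The new code also adds the sampled (primary) token back in when the provider omits it from `top_logprobs`, treating a reported logprob of `-9999.0` as weight 1.0. A small worked sketch of the weighting, with simplified names rather than the class itself:

```python
import math

# Alternatives reported in top_logprobs: "4" at 60% and "5" at 40%.
top_logprobs = {"4": math.log(0.6), "5": math.log(0.4)}
token_to_score = {"4": 4.0, "5": 5.0}

total_score = 0.0
total_probability = 0.0
for token, logprob in top_logprobs.items():
    p = math.exp(logprob)  # convert the logprob back to a probability
    total_score += token_to_score[token] * p
    total_probability += p

# Normalize by the mass of valid rating tokens only.
print(total_score / total_probability)  # 4.4, the same value asserted in the tests
```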
kiln_ai/adapters/eval/test_g_eval.py
CHANGED
@@ -2,6 +2,7 @@ import math
 import pickle
 
 import pytest
+
 from kiln_ai.adapters.eval.g_eval import TOKEN_TO_SCORE_MAP, GEval, GEvalTask
 from kiln_ai.adapters.eval.test_g_eval_data import serialized_run_output
 from kiln_ai.adapters.ml_model_list import built_in_models
@@ -392,12 +393,13 @@ def test_rating_token_to_score(test_eval_config, test_run_config):
             self.logprob = logprob
 
     class MockTokenLogprob:
-        def __init__(self, token, top_logprobs):
+        def __init__(self, token, top_logprobs, logprob):
             self.token = token
             self.top_logprobs = [MockTopLogprob(t, lp) for t, lp in top_logprobs]
+            self.logprob = logprob
 
     # Test single token case
-    token_logprob = MockTokenLogprob("5", [("5", 0.0)])  # log(1) = 0
+    token_logprob = MockTokenLogprob("5", [("5", 0.0)], logprob=1e-8)  # log(1) = 0
     score = g_eval.rating_token_to_score(token_logprob)
     assert score == 5.0
 
@@ -408,18 +410,22 @@ def test_rating_token_to_score(test_eval_config, test_run_config):
             ("4", math.log(0.6)),  # 60% probability
             ("5", math.log(0.4)),  # 40% probability
         ],
+        logprob=math.log(0.6),
     )
     score = g_eval.rating_token_to_score(token_logprob)
     assert pytest.approx(score) == 4.4  # (4 * 0.6 + 5 * 0.4)
 
     # Test invalid token
-    token_logprob = MockTokenLogprob(":", [(":", 0.0)])
+    token_logprob = MockTokenLogprob(":", [(":", 0.0)], logprob=1e-8)
     assert g_eval.rating_token_to_score(token_logprob) is None
 
-    # Test
-    token_logprob = MockTokenLogprob("5", [])
-
-
+    # Test missing from top logprobs
+    token_logprob = MockTokenLogprob("5", [], logprob=1e-8)
+    assert pytest.approx(g_eval.rating_token_to_score(token_logprob)) == 5.0
+
+    # Test missing from top logprobs, with special case logprob
+    token_logprob = MockTokenLogprob("5", [], logprob=-9999)
+    assert pytest.approx(g_eval.rating_token_to_score(token_logprob)) == 5.0
 
 
 def test_g_eval_system_instruction():
kiln_ai/adapters/fine_tune/base_finetune.py
CHANGED
@@ -4,7 +4,12 @@ from typing import Literal
 from pydantic import BaseModel
 
 from kiln_ai.adapters.ml_model_list import built_in_models
-from kiln_ai.datamodel import
+from kiln_ai.datamodel import (
+    DatasetSplit,
+    FinetuneDataStrategy,
+    FineTuneStatusType,
+    Task,
+)
 from kiln_ai.datamodel import Finetune as FinetuneModel
 from kiln_ai.utils.name_generator import generate_memorable_name
 
@@ -101,7 +106,7 @@ class BaseFinetuneAdapter(ABC):
             train_split_name=train_split_name,
             validation_split_name=validation_split_name,
             parameters=parameters,
-            system_message=system_message,
+            system_message=cls.augment_system_message(system_message, parent_task),
             thinking_instructions=thinking_instructions,
             parent=parent_task,
             data_strategy=data_strategy,
@@ -114,6 +119,15 @@ class BaseFinetuneAdapter(ABC):
 
         return adapter, datamodel
 
+    @classmethod
+    def augment_system_message(cls, system_message: str, task: Task) -> str:
+        """
+        Augment the system message with additional instructions, such as JSON instructions.
+        """
+
+        # Base implementation does nothing, can be overridden by subclasses
+        return system_message
+
     @abstractmethod
     async def _start(self, dataset: DatasetSplit) -> None:
         """
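`augment_system_message` is a new hook on `BaseFinetuneAdapter`: the adapter now passes the system message through it before persisting the `Finetune` datamodel, and the base implementation returns the message unchanged. A hedged sketch of how a provider adapter might override it follows; the subclass name, the schema check, and the appended text are illustrative assumptions, not taken from this release.

```python
from kiln_ai.adapters.fine_tune.base_finetune import BaseFinetuneAdapter
from kiln_ai.datamodel import Task


class ExampleFinetune(BaseFinetuneAdapter):  # hypothetical subclass for illustration
    @classmethod
    def augment_system_message(cls, system_message: str, task: Task) -> str:
        # A provider without native structured-output support might append JSON
        # instructions when the task declares an output schema (assumed field name).
        if task.output_json_schema:
            return system_message + "\n\nReturn only JSON matching the output schema."
        return system_message
```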
kiln_ai/adapters/fine_tune/finetune_registry.py
CHANGED
@@ -3,9 +3,11 @@ from typing import Type
 from kiln_ai.adapters.fine_tune.base_finetune import BaseFinetuneAdapter
 from kiln_ai.adapters.fine_tune.fireworks_finetune import FireworksFinetune
 from kiln_ai.adapters.fine_tune.openai_finetune import OpenAIFinetune
+from kiln_ai.adapters.fine_tune.together_finetune import TogetherFinetune
 from kiln_ai.adapters.ml_model_list import ModelProviderName
 
 finetune_registry: dict[ModelProviderName, Type[BaseFinetuneAdapter]] = {
     ModelProviderName.openai: OpenAIFinetune,
     ModelProviderName.fireworks_ai: FireworksFinetune,
+    ModelProviderName.together_ai: TogetherFinetune,
 }
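With Together AI registered, fine-tune adapter classes are resolved by provider through this dict. A short lookup sketch:

```python
from kiln_ai.adapters.fine_tune.finetune_registry import finetune_registry
from kiln_ai.adapters.ml_model_list import ModelProviderName

# Providers without fine-tuning support simply have no entry in the registry.
adapter_cls = finetune_registry.get(ModelProviderName.together_ai)
if adapter_cls is not None:
    print(f"Fine-tuning via {adapter_cls.__name__}")  # TogetherFinetune
```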
kiln_ai/adapters/fine_tune/fireworks_finetune.py
CHANGED
@@ -132,11 +132,18 @@ class FireworksFinetune(BaseFinetuneAdapter):
                 :60
             ]
         )
-        payload = {
+        payload: dict[str, str | dict[str, str | bool]] = {
             "dataset": f"accounts/{account_id}/datasets/{train_file_id}",
             "displayName": display_name,
             "baseModel": self.datamodel.base_model_id,
         }
+        # Add W&B config if API key is set
+        if Config.shared().wandb_api_key:
+            payload["wandbConfig"] = {
+                "enabled": True,
+                "project": "Kiln_AI",
+                "apiKey": Config.shared().wandb_api_key,
+            }
         hyperparameters = self.create_payload_parameters(self.datamodel.parameters)
         payload.update(hyperparameters)
         headers = {
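The Fireworks job submission payload now carries a `wandbConfig` block whenever a Weights & Biases API key is configured in Kiln settings. The resulting payload looks roughly like the sketch below; the account, dataset, and job ids are placeholders, while the `baseModel` and `displayName` formats match the tests that follow.

```python
# Illustrative payload shape for the Fireworks fine-tune job request.
payload = {
    "dataset": "accounts/acct_123/datasets/ds_456",  # placeholder ids
    "displayName": "Kiln AI fine-tuning [ID:ft_789][name:example]",
    "baseModel": "llama-v2-7b",
    # Included only when Config.shared().wandb_api_key is set:
    "wandbConfig": {
        "enabled": True,
        "project": "Kiln_AI",
        "apiKey": "<your W&B API key>",
    },
}
# Hyperparameters from the Kiln finetune datamodel are merged in afterwards:
# payload.update(hyperparameters)
```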
kiln_ai/adapters/fine_tune/test_fireworks_finetune.py
CHANGED
@@ -340,6 +340,7 @@ async def test_start_success(
     expected_mode,
     expected_format,
 ):
+    Config.shared().wandb_api_key = "test-api-key"
     mock_task.output_json_schema = output_schema
 
     fireworks_finetune.datamodel.parent = mock_task
@@ -378,6 +379,24 @@
     assert fireworks_finetune.datamodel.structured_output_mode == expected_mode
     assert fireworks_finetune.datamodel.properties["endpoint_version"] == "v2"
 
+    # Check mock_client.post call values
+    assert mock_client.post.call_count == 1
+    submit_call_values = mock_client.post.call_args[1]
+    assert submit_call_values["json"]["wandbConfig"] == {
+        "enabled": True,
+        "project": "Kiln_AI",
+        "apiKey": "test-api-key",
+    }
+    assert submit_call_values["json"]["baseModel"] == "llama-v2-7b"
+    assert (
+        submit_call_values["json"]["dataset"]
+        == f"accounts/{Config.shared().fireworks_account_id}/datasets/{mock_dataset_id}"
+    )
+    assert (
+        submit_call_values["json"]["displayName"]
+        == f"Kiln AI fine-tuning [ID:{fireworks_finetune.datamodel.id}][name:{fireworks_finetune.datamodel.name}]"
+    )
+
 
 async def test_start_api_error(
     fireworks_finetune, mock_dataset, mock_task, mock_api_key