kiln-ai 0.14.0__py3-none-any.whl → 0.15.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of kiln-ai might be problematic.
- kiln_ai/adapters/eval/base_eval.py +7 -2
- kiln_ai/adapters/fine_tune/base_finetune.py +6 -3
- kiln_ai/adapters/fine_tune/dataset_formatter.py +4 -4
- kiln_ai/adapters/fine_tune/finetune_registry.py +2 -0
- kiln_ai/adapters/fine_tune/fireworks_finetune.py +2 -1
- kiln_ai/adapters/fine_tune/test_base_finetune.py +7 -0
- kiln_ai/adapters/fine_tune/test_dataset_formatter.py +3 -3
- kiln_ai/adapters/fine_tune/test_fireworks_finetune.py +1 -1
- kiln_ai/adapters/fine_tune/test_vertex_finetune.py +586 -0
- kiln_ai/adapters/fine_tune/vertex_finetune.py +217 -0
- kiln_ai/adapters/ml_model_list.py +318 -37
- kiln_ai/adapters/model_adapters/base_adapter.py +15 -10
- kiln_ai/adapters/model_adapters/litellm_adapter.py +10 -5
- kiln_ai/adapters/provider_tools.py +7 -0
- kiln_ai/adapters/test_provider_tools.py +16 -0
- kiln_ai/datamodel/json_schema.py +24 -7
- kiln_ai/datamodel/task_output.py +9 -5
- kiln_ai/datamodel/task_run.py +29 -5
- kiln_ai/datamodel/test_example_models.py +104 -3
- kiln_ai/datamodel/test_json_schema.py +22 -3
- kiln_ai/datamodel/test_model_perf.py +3 -2
- {kiln_ai-0.14.0.dist-info → kiln_ai-0.15.0.dist-info}/METADATA +3 -2
- {kiln_ai-0.14.0.dist-info → kiln_ai-0.15.0.dist-info}/RECORD +25 -24
- kiln_ai/adapters/test_generate_docs.py +0 -69
- {kiln_ai-0.14.0.dist-info → kiln_ai-0.15.0.dist-info}/WHEEL +0 -0
- {kiln_ai-0.14.0.dist-info → kiln_ai-0.15.0.dist-info}/licenses/LICENSE.txt +0 -0
kiln_ai/adapters/eval/base_eval.py

@@ -2,11 +2,13 @@ import json
 from abc import abstractmethod
 from typing import Dict

+import jsonschema
+
 from kiln_ai.adapters.adapter_registry import adapter_for_task
 from kiln_ai.adapters.ml_model_list import ModelProviderName
 from kiln_ai.adapters.model_adapters.base_adapter import AdapterConfig
 from kiln_ai.datamodel.eval import Eval, EvalConfig, EvalScores
-from kiln_ai.datamodel.json_schema import
+from kiln_ai.datamodel.json_schema import validate_schema_with_value_error
 from kiln_ai.datamodel.task import RunConfig, TaskOutputRatingType, TaskRun
 from kiln_ai.utils.exhaustive_error import raise_exhaustive_enum_error

@@ -72,7 +74,10 @@ class BaseEval:
         run_output = await run_adapter.invoke(parsed_input)

         eval_output, intermediate_outputs = await self.run_eval(run_output)
-
+
+        validate_schema_with_value_error(
+            eval_output, self.score_schema, "Eval output does not match score schema."
+        )

         return run_output, eval_output, intermediate_outputs

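The second hunk above adds a validation step: the judge's eval_output is now checked against the eval's generated score schema before being returned. Below is a minimal sketch of the same idea using the jsonschema package this file now imports; the schema and helper name are illustrative, not Kiln's validate_schema_with_value_error implementation.

```python
import jsonschema

# Hypothetical score schema; BaseEval generates its own from the eval definition.
score_schema = {
    "type": "object",
    "properties": {"overall_rating": {"type": "number"}},
    "required": ["overall_rating"],
}

def check_eval_output(eval_output: dict) -> None:
    """Raise a readable ValueError when the eval output doesn't match the schema."""
    try:
        jsonschema.validate(instance=eval_output, schema=score_schema)
    except jsonschema.ValidationError as e:
        raise ValueError(
            f"Eval output does not match score schema: {e.message}"
        ) from e

check_eval_output({"overall_rating": 4.0})   # passes
# check_eval_output({"overall_rating": "A"}) # would raise ValueError
```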
kiln_ai/adapters/fine_tune/base_finetune.py

@@ -166,9 +166,12 @@ class BaseFinetuneAdapter(ABC):

             # Strict type checking for numeric types
             if expected_type is float and not isinstance(value, float):
-
-
-
+                if isinstance(value, int):
+                    value = float(value)
+                else:
+                    raise ValueError(
+                        f"Parameter {parameter.name} must be a float, got {type(value)}"
+                    )
             elif expected_type is int and not isinstance(value, int):
                 raise ValueError(
                     f"Parameter {parameter.name} must be an integer, got {type(value)}"
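This hunk relaxes strict numeric checking so that an int supplied for a float-typed hyperparameter is coerced rather than rejected. A standalone sketch of that rule, with an illustrative function name rather than Kiln's actual API:

```python
def coerce_float_parameter(name: str, value: object) -> float:
    """Accept floats as-is, promote ints to float, reject everything else."""
    if isinstance(value, float):
        return value
    if isinstance(value, int):
        return float(value)  # e.g. learning_rate=1 becomes 1.0
    raise ValueError(f"Parameter {name} must be a float, got {type(value)}")

assert coerce_float_parameter("learning_rate", 1) == 1.0
assert coerce_float_parameter("learning_rate", 0.001) == 0.001
```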
kiln_ai/adapters/fine_tune/dataset_formatter.py

@@ -30,8 +30,8 @@ class DatasetFormat(str, Enum):
         "huggingface_chat_template_toolcall_jsonl"
     )

-    """Vertex Gemini
-
+    """Vertex Gemini format"""
+    VERTEX_GEMINI = "vertex_gemini"


 @dataclass
@@ -288,7 +288,7 @@ def generate_huggingface_chat_template_toolcall(
     return {"conversations": conversations}


-def
+def generate_vertex_gemini(
     training_data: ModelTrainingData,
 ) -> Dict[str, Any]:
     """Generate Vertex Gemini 1.5 format (flash and pro)"""
@@ -346,7 +346,7 @@ FORMAT_GENERATORS: Dict[DatasetFormat, FormatGenerator] = {
     DatasetFormat.OPENAI_CHAT_TOOLCALL_JSONL: generate_chat_message_toolcall,
     DatasetFormat.HUGGINGFACE_CHAT_TEMPLATE_JSONL: generate_huggingface_chat_template,
     DatasetFormat.HUGGINGFACE_CHAT_TEMPLATE_TOOLCALL_JSONL: generate_huggingface_chat_template_toolcall,
-    DatasetFormat.
+    DatasetFormat.VERTEX_GEMINI: generate_vertex_gemini,
 }

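These hunks add a VERTEX_GEMINI dataset format and wire generate_vertex_gemini into the FORMAT_GENERATORS table. The sketch below shows roughly what one training record in that format could look like: the top-level systemInstruction key matches the tests later in this diff, while the contents/parts layout is an assumption based on Vertex AI's supervised-tuning JSONL format and may not match Kiln's exact output.

```python
# Hypothetical single JSONL record for Vertex Gemini supervised tuning.
example_record = {
    "systemInstruction": {
        "role": "system",
        "parts": [{"text": "You are a helpful assistant."}],
    },
    "contents": [
        {"role": "user", "parts": [{"text": "test input"}]},
        {"role": "model", "parts": [{"text": "test output"}]},
    ],
}
```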
kiln_ai/adapters/fine_tune/finetune_registry.py

@@ -4,10 +4,12 @@ from kiln_ai.adapters.fine_tune.base_finetune import BaseFinetuneAdapter
 from kiln_ai.adapters.fine_tune.fireworks_finetune import FireworksFinetune
 from kiln_ai.adapters.fine_tune.openai_finetune import OpenAIFinetune
 from kiln_ai.adapters.fine_tune.together_finetune import TogetherFinetune
+from kiln_ai.adapters.fine_tune.vertex_finetune import VertexFinetune
 from kiln_ai.adapters.ml_model_list import ModelProviderName

 finetune_registry: dict[ModelProviderName, Type[BaseFinetuneAdapter]] = {
     ModelProviderName.openai: OpenAIFinetune,
     ModelProviderName.fireworks_ai: FireworksFinetune,
     ModelProviderName.together_ai: TogetherFinetune,
+    ModelProviderName.vertex: VertexFinetune,
 }
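With the new registry entry, Vertex fine-tunes resolve the same way as the other providers. A sketch of the usual lookup, assuming kiln-ai 0.15.0 is installed; constructor arguments are omitted because they are not part of this diff:

```python
from kiln_ai.adapters.fine_tune.finetune_registry import finetune_registry
from kiln_ai.adapters.ml_model_list import ModelProviderName

# Resolve the adapter class for the Vertex provider.
adapter_cls = finetune_registry[ModelProviderName.vertex]
print(adapter_cls.__name__)  # VertexFinetune
```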
kiln_ai/adapters/fine_tune/fireworks_finetune.py

@@ -198,7 +198,8 @@ class FireworksFinetune(BaseFinetuneAdapter):
         if not api_key or not account_id:
             raise ValueError("Fireworks API key or account ID not set")
         url = f"https://api.fireworks.ai/v1/accounts/{account_id}/datasets"
-
+        # First char can't be a digit: https://discord.com/channels/1137072072808472616/1363214412395184350/1363214412395184350
+        dataset_id = "kiln-" + str(uuid4())
         payload = {
             "datasetId": dataset_id,
             "dataset": {
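The dataset ID now gets a "kiln-" prefix because, per the comment in the hunk, Fireworks rejects dataset IDs whose first character is a digit, and a bare uuid4 string can start with one. A quick illustration:

```python
from uuid import uuid4

dataset_id = "kiln-" + str(uuid4())
assert not dataset_id[0].isdigit()  # the prefix guarantees a letter first
```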
kiln_ai/adapters/fine_tune/test_base_finetune.py

@@ -98,6 +98,13 @@ def test_validate_parameters_valid():
     }
     MockFinetune.validate_parameters(valid_params)  # Should not raise

+    # Test valid parameters (float as int)
+    valid_params = {
+        "learning_rate": 1,
+        "epochs": 10,
+    }
+    MockFinetune.validate_parameters(valid_params)  # Should not raise
+

 def test_validate_parameters_missing_required():
     # Test missing required parameter
kiln_ai/adapters/fine_tune/test_dataset_formatter.py

@@ -15,7 +15,7 @@ from kiln_ai.adapters.fine_tune.dataset_formatter import (
     generate_chat_message_toolcall,
     generate_huggingface_chat_template,
     generate_huggingface_chat_template_toolcall,
-
+    generate_vertex_gemini,
 )
 from kiln_ai.adapters.model_adapters.base_adapter import COT_FINAL_ANSWER_PROMPT
 from kiln_ai.datamodel import (
@@ -447,7 +447,7 @@ def test_generate_vertex_template():
         final_output="test output",
     )

-    result =
+    result = generate_vertex_gemini(training_data)

     assert result == {
         "systemInstruction": {
@@ -475,7 +475,7 @@ def test_generate_vertex_template_thinking():
         thinking_final_answer_prompt="thinking final answer prompt",
     )

-    result =
+    result = generate_vertex_gemini(training_data)

     logger.info(result)

kiln_ai/adapters/fine_tune/test_fireworks_finetune.py

@@ -315,7 +315,7 @@ async def test_generate_and_upload_jsonl_success(
         "thinking_instructions": thinking_instructions,
     }

-    assert result == mock_dataset_id
+    assert result == "kiln-" + mock_dataset_id
     assert mock_client.post.call_count == 2
     assert mock_client.get.call_count == 1
