kiln-ai 0.8.1__py3-none-any.whl → 0.12.0__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their public registry, and is provided for informational purposes only.
Potentially problematic release: this version of kiln-ai has been flagged as potentially problematic.
- kiln_ai/adapters/__init__.py +7 -7
- kiln_ai/adapters/adapter_registry.py +81 -10
- kiln_ai/adapters/data_gen/data_gen_task.py +21 -3
- kiln_ai/adapters/data_gen/test_data_gen_task.py +23 -3
- kiln_ai/adapters/eval/base_eval.py +164 -0
- kiln_ai/adapters/eval/eval_runner.py +267 -0
- kiln_ai/adapters/eval/g_eval.py +367 -0
- kiln_ai/adapters/eval/registry.py +16 -0
- kiln_ai/adapters/eval/test_base_eval.py +324 -0
- kiln_ai/adapters/eval/test_eval_runner.py +640 -0
- kiln_ai/adapters/eval/test_g_eval.py +497 -0
- kiln_ai/adapters/eval/test_g_eval_data.py +4 -0
- kiln_ai/adapters/fine_tune/base_finetune.py +5 -1
- kiln_ai/adapters/fine_tune/dataset_formatter.py +310 -65
- kiln_ai/adapters/fine_tune/fireworks_finetune.py +47 -32
- kiln_ai/adapters/fine_tune/openai_finetune.py +12 -11
- kiln_ai/adapters/fine_tune/test_base_finetune.py +19 -0
- kiln_ai/adapters/fine_tune/test_dataset_formatter.py +472 -129
- kiln_ai/adapters/fine_tune/test_fireworks_finetune.py +114 -22
- kiln_ai/adapters/fine_tune/test_openai_finetune.py +125 -14
- kiln_ai/adapters/ml_model_list.py +434 -93
- kiln_ai/adapters/model_adapters/__init__.py +18 -0
- kiln_ai/adapters/model_adapters/base_adapter.py +250 -0
- kiln_ai/adapters/model_adapters/langchain_adapters.py +309 -0
- kiln_ai/adapters/model_adapters/openai_compatible_config.py +10 -0
- kiln_ai/adapters/model_adapters/openai_model_adapter.py +289 -0
- kiln_ai/adapters/model_adapters/test_base_adapter.py +199 -0
- kiln_ai/adapters/{test_langchain_adapter.py → model_adapters/test_langchain_adapter.py} +105 -97
- kiln_ai/adapters/model_adapters/test_openai_model_adapter.py +216 -0
- kiln_ai/adapters/{test_saving_adapter_results.py → model_adapters/test_saving_adapter_results.py} +80 -30
- kiln_ai/adapters/{test_structured_output.py → model_adapters/test_structured_output.py} +125 -46
- kiln_ai/adapters/ollama_tools.py +0 -1
- kiln_ai/adapters/parsers/__init__.py +10 -0
- kiln_ai/adapters/parsers/base_parser.py +12 -0
- kiln_ai/adapters/parsers/json_parser.py +37 -0
- kiln_ai/adapters/parsers/parser_registry.py +19 -0
- kiln_ai/adapters/parsers/r1_parser.py +69 -0
- kiln_ai/adapters/parsers/test_json_parser.py +81 -0
- kiln_ai/adapters/parsers/test_parser_registry.py +32 -0
- kiln_ai/adapters/parsers/test_r1_parser.py +144 -0
- kiln_ai/adapters/prompt_builders.py +193 -49
- kiln_ai/adapters/provider_tools.py +91 -36
- kiln_ai/adapters/repair/repair_task.py +18 -19
- kiln_ai/adapters/repair/test_repair_task.py +7 -7
- kiln_ai/adapters/run_output.py +11 -0
- kiln_ai/adapters/test_adapter_registry.py +177 -0
- kiln_ai/adapters/test_generate_docs.py +69 -0
- kiln_ai/adapters/test_ollama_tools.py +0 -1
- kiln_ai/adapters/test_prompt_adaptors.py +25 -18
- kiln_ai/adapters/test_prompt_builders.py +265 -44
- kiln_ai/adapters/test_provider_tools.py +268 -46
- kiln_ai/datamodel/__init__.py +51 -772
- kiln_ai/datamodel/basemodel.py +31 -11
- kiln_ai/datamodel/datamodel_enums.py +58 -0
- kiln_ai/datamodel/dataset_filters.py +114 -0
- kiln_ai/datamodel/dataset_split.py +170 -0
- kiln_ai/datamodel/eval.py +298 -0
- kiln_ai/datamodel/finetune.py +105 -0
- kiln_ai/datamodel/json_schema.py +14 -3
- kiln_ai/datamodel/model_cache.py +8 -3
- kiln_ai/datamodel/project.py +23 -0
- kiln_ai/datamodel/prompt.py +37 -0
- kiln_ai/datamodel/prompt_id.py +83 -0
- kiln_ai/datamodel/strict_mode.py +24 -0
- kiln_ai/datamodel/task.py +181 -0
- kiln_ai/datamodel/task_output.py +321 -0
- kiln_ai/datamodel/task_run.py +164 -0
- kiln_ai/datamodel/test_basemodel.py +80 -2
- kiln_ai/datamodel/test_dataset_filters.py +71 -0
- kiln_ai/datamodel/test_dataset_split.py +127 -6
- kiln_ai/datamodel/test_datasource.py +3 -2
- kiln_ai/datamodel/test_eval_model.py +635 -0
- kiln_ai/datamodel/test_example_models.py +34 -17
- kiln_ai/datamodel/test_json_schema.py +23 -0
- kiln_ai/datamodel/test_model_cache.py +24 -0
- kiln_ai/datamodel/test_model_perf.py +125 -0
- kiln_ai/datamodel/test_models.py +131 -2
- kiln_ai/datamodel/test_prompt_id.py +129 -0
- kiln_ai/datamodel/test_task.py +159 -0
- kiln_ai/utils/config.py +6 -1
- kiln_ai/utils/exhaustive_error.py +6 -0
- {kiln_ai-0.8.1.dist-info → kiln_ai-0.12.0.dist-info}/METADATA +45 -7
- kiln_ai-0.12.0.dist-info/RECORD +100 -0
- kiln_ai/adapters/base_adapter.py +0 -191
- kiln_ai/adapters/langchain_adapters.py +0 -256
- kiln_ai-0.8.1.dist-info/RECORD +0 -58
- {kiln_ai-0.8.1.dist-info → kiln_ai-0.12.0.dist-info}/WHEEL +0 -0
- {kiln_ai-0.8.1.dist-info → kiln_ai-0.12.0.dist-info}/licenses/LICENSE.txt +0 -0
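
As the file listing above shows, the single-module adapters (base_adapter.py, langchain_adapters.py) were removed in favor of a kiln_ai.adapters.model_adapters package, alongside new eval and parsers packages, and the monolithic kiln_ai/datamodel/__init__.py was split into per-concern modules. A minimal sketch of how imports move across this boundary, based only on the import paths visible in the hunks below (other paths may differ):

```python
# Import paths confirmed by the test hunks in this diff; anything else is untested here.
# 0.8.x (removed): kiln_ai.adapters.base_adapter, kiln_ai.adapters.langchain_adapters
# 0.12.x (new locations):
from kiln_ai.adapters.model_adapters.base_adapter import BaseAdapter  # moved under model_adapters/
from kiln_ai.adapters.run_output import RunOutput  # new run output container
from kiln_ai.datamodel.task import RunConfig  # datamodel split into modules
from kiln_ai.datamodel import Task, TaskRun  # still available at the package top level
```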
kiln_ai/datamodel/test_basemodel.py

@@ -6,12 +6,16 @@ from unittest.mock import MagicMock, patch
 
 import pytest
 
+from kiln_ai.adapters.model_adapters.base_adapter import BaseAdapter
+from kiln_ai.adapters.run_output import RunOutput
+from kiln_ai.datamodel import Task, TaskRun
 from kiln_ai.datamodel.basemodel import (
     KilnBaseModel,
     KilnParentedModel,
     string_to_valid_name,
 )
 from kiln_ai.datamodel.model_cache import ModelCache
+from kiln_ai.datamodel.task import RunConfig
 
 
 @pytest.fixture
@@ -356,7 +360,9 @@ def test_load_from_file_with_cache(test_base_file, tmp_model_cache):
     model = KilnBaseModel.load_from_file(test_base_file)
 
     # Check that the cache was checked and set
-    tmp_model_cache.get_model.assert_called_once_with(
+    tmp_model_cache.get_model.assert_called_once_with(
+        test_base_file, KilnBaseModel, readonly=False
+    )
     tmp_model_cache.set_model.assert_called_once()
 
     # Ensure the model is correctly loaded
@@ -407,7 +413,9 @@ def test_load_from_file_with_cached_model(test_base_file, tmp_model_cache):
     model = KilnBaseModel.load_from_file(test_base_file)
 
     # Check that the cache was checked and the cached model was returned
-    tmp_model_cache.get_model.assert_called_once_with(
+    tmp_model_cache.get_model.assert_called_once_with(
+        test_base_file, KilnBaseModel, readonly=False
+    )
     assert model is cached_model
 
     # Assert that open was not called (we used the cached model, not file)
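
Both cache hunks update the expected ModelCache.get_model call to include a readonly flag alongside the path and model class. The real signature is not shown in this diff; a hedged sketch of the call shape the assertions imply (method body and default are assumptions):

```python
from pathlib import Path
from typing import Optional, Type, TypeVar

from kiln_ai.datamodel.basemodel import KilnBaseModel

T = TypeVar("T", bound=KilnBaseModel)


class ModelCacheSketch:
    """Illustrative only: mirrors the call shape asserted in the tests above."""

    def get_model(self, path: Path, model_type: Type[T], readonly: bool = False) -> Optional[T]:
        # The shipped ModelCache presumably returns a cached instance (possibly
        # marked read-only) or None on a cache miss; this sketch only fixes the shape.
        raise NotImplementedError
```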
@@ -469,3 +477,73 @@ def test_from_id_and_parent_path_without_parent():
     # Test with None parent_path
     not_found = DefaultParentedModel.from_id_and_parent_path("any-id", None)
     assert not_found is None
+
+
+class MockAdapter(BaseAdapter):
+    """Implementation of BaseAdapter for testing"""
+
+    async def _run(self, input):
+        return RunOutput(output="test output", intermediate_outputs=None)
+
+    def adapter_name(self) -> str:
+        return "test"
+
+
+@pytest.fixture
+def base_task():
+    return Task(name="test_task", instruction="test_instruction")
+
+
+@pytest.fixture
+def adapter(base_task):
+    return MockAdapter(
+        run_config=RunConfig(
+            task=base_task,
+            model_name="test_model",
+            model_provider_name="test_provider",
+            prompt_id="simple_prompt_builder",
+        ),
+    )
+
+
+async def test_invoke_parsing_flow(adapter):
+    # Mock dependencies
+    mock_provider = MagicMock()
+    mock_provider.parser = "test_parser"
+
+    mock_parser = MagicMock()
+    mock_parser.parse_output.return_value = RunOutput(
+        output="parsed test output", intermediate_outputs={"key": "value"}
+    )
+
+    mock_parser_class = MagicMock(return_value=mock_parser)
+
+    with (
+        patch.object(adapter, "model_provider", return_value=mock_provider),
+        patch(
+            "kiln_ai.adapters.model_adapters.base_adapter.model_parser_from_id",
+            return_value=mock_parser_class,
+        ),
+        patch("kiln_ai.adapters.model_adapters.base_adapter.Config") as mock_config,
+    ):
+        # Disable autosaving for this test
+        mock_config.shared.return_value.autosave_runs = False
+        mock_config.shared.return_value.user_id = "test_user_id"
+
+        # Execute
+        result = await adapter.invoke("test input")
+
+        # Verify parser was created correctly
+        mock_parser_class.assert_called_once_with(structured_output=False)
+
+        # Verify parsing occurred
+        mock_parser.parse_output.assert_called_once()
+        parsed_args = mock_parser.parse_output.call_args[1]
+        assert isinstance(parsed_args["original_output"], RunOutput)
+        assert parsed_args["original_output"].output == "test output"
+
+        # Verify result contains parsed output
+        assert isinstance(result, TaskRun)
+        assert result.output.output == "parsed test output"
+        assert result.intermediate_outputs == {"key": "value"}
+        assert result.input == "test input"
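
The new test above pins down the parsing flow in BaseAdapter.invoke: the raw RunOutput from _run is handed to a parser looked up via model_parser_from_id, the parser class is constructed with structured_output=..., and its parse_output(original_output=...) result becomes the TaskRun output. A sketch of a parser satisfying that mocked interface (the class name and behavior here are illustrative, not the shipped parsers in kiln_ai/adapters/parsers/):

```python
from kiln_ai.adapters.run_output import RunOutput


class UppercaseParser:
    """Illustrative parser matching the interface the test exercises:
    constructed with structured_output, exposes parse_output(original_output=...)."""

    def __init__(self, structured_output: bool = False):
        self.structured_output = structured_output

    def parse_output(self, original_output: RunOutput) -> RunOutput:
        # Post-process the model's raw text; real parsers (e.g. the R1 parser added
        # in this release) would split thinking from the final answer instead.
        return RunOutput(
            output=str(original_output.output).upper(),
            intermediate_outputs=original_output.intermediate_outputs,
        )
```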
kiln_ai/datamodel/test_dataset_filters.py (new file)

@@ -0,0 +1,71 @@
+import pytest
+from pydantic import BaseModel
+
+from kiln_ai.datamodel.dataset_filters import (
+    AllDatasetFilter,
+    DatasetFilterId,
+    HighRatingDatasetFilter,
+    StaticDatasetFilters,
+    TagFilter,
+    ThinkingModelDatasetFilter,
+    ThinkingModelHighRatedFilter,
+    dataset_filter_from_id,
+)
+
+# Note: Many more filter tests in test_dataset_split.py
+
+
+def test_all_dataset_filter_from_id():
+    assert dataset_filter_from_id("all") == AllDatasetFilter
+
+
+def test_high_rating_dataset_filter_from_id():
+    assert dataset_filter_from_id("high_rating") == HighRatingDatasetFilter
+
+
+def test_thinking_model_dataset_filter_from_id():
+    assert dataset_filter_from_id("thinking_model") == ThinkingModelDatasetFilter
+
+
+def test_thinking_model_high_rated_dataset_filter_from_id():
+    assert (
+        dataset_filter_from_id("thinking_model_high_rated")
+        == ThinkingModelHighRatedFilter
+    )
+
+
+def test_all_static_dataset_filters():
+    for filter_id in StaticDatasetFilters:
+        assert dataset_filter_from_id(filter_id) is not None
+
+
+class ModelTester(BaseModel):
+    dsid: DatasetFilterId
+
+
+@pytest.mark.parametrize(
+    "tag,expected_error,expected_tag",
+    [
+        ("tag::test", False, "test"),
+        ("tag::other", False, "other"),
+        ("tag::", True, None),
+        ("tag", True, None),
+        ("", True, None),
+    ],
+)
+def test_tag_filter(tag, expected_error, expected_tag):
+    # Check our model validators
+    if expected_error:
+        with pytest.raises(ValueError):
+            ModelTester(dsid=tag)
+    else:
+        ModelTester(dsid=tag)
+
+    # Check the constructor
+    if expected_tag is None:
+        with pytest.raises(ValueError, match="Invalid dataset filter ID:"):
+            dataset_filter_from_id(tag)
+    else:
+        filter = dataset_filter_from_id(tag)
+        assert isinstance(filter, TagFilter)
+        assert filter.tag == expected_tag
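
This new test file exercises the public lookup API for dataset filters: the static IDs ("all", "high_rating", "thinking_model", "thinking_model_high_rated") resolve to filter callables, and a "tag::<name>" ID builds a TagFilter. A short usage sketch based on those tests; the tag name "golden" and the task_runs variable are placeholders:

```python
from kiln_ai.datamodel.dataset_filters import dataset_filter_from_id

high_rating = dataset_filter_from_id("high_rating")  # static filter
tag_filter = dataset_filter_from_id("tag::golden")   # parameterized TagFilter, .tag == "golden"

# Filters are callables over TaskRun instances (per the tests in this diff), e.g.:
# selected = [run for run in task_runs if tag_filter(run)]
```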
kiln_ai/datamodel/test_dataset_split.py

@@ -3,21 +3,28 @@ from pydantic import ValidationError
 
 # import datamodel first or we get circular import errors
 from kiln_ai.datamodel import (
-    AllDatasetFilter,
-    AllSplitDefinition,
     DatasetSplit,
     DatasetSplitDefinition,
     DataSource,
     DataSourceType,
-    HighRatingDatasetFilter,
     Task,
     TaskOutput,
     TaskOutputRating,
     TaskOutputRatingType,
     TaskRun,
+)
+from kiln_ai.datamodel.dataset_split import (
+    AllSplitDefinition,
     Train60Test20Val20SplitDefinition,
     Train80Test20SplitDefinition,
 )
+from kiln_ai.datamodel.test_dataset_filters import (
+    AllDatasetFilter,
+    HighRatingDatasetFilter,
+    TagFilter,
+    ThinkingModelDatasetFilter,
+    ThinkingModelHighRatedFilter,
+)
 
 
 @pytest.fixture
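
For downstream callers, this hunk documents the import moves that accompany the datamodel split: split definitions now live in kiln_ai.datamodel.dataset_split and the filters in kiln_ai.datamodel.dataset_filters (the test re-imports them via test_dataset_filters, but the canonical module per the file listing is dataset_filters). A minimal sketch of the new locations:

```python
from kiln_ai.datamodel.dataset_filters import HighRatingDatasetFilter, TagFilter
from kiln_ai.datamodel.dataset_split import (
    AllSplitDefinition,
    Train80Test20SplitDefinition,
)
```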
@@ -39,6 +46,7 @@ def sample_task_runs(sample_task):
     task_runs = []
     for i in range(10):
         rating = 5 if i < 6 else 1  # 6 high, 4 low ratings
+        tags = ["tag1"] if i < 6 else []
         task_run = TaskRun(
             parent=sample_task,
             input=f"input_{i}",
@@ -56,6 +64,7 @@ def sample_task_runs(sample_task):
                     value=rating, type=TaskOutputRatingType.five_star
                 ),
             ),
+            tags=tags,
         )
         task_run.save_to_file()
         task_runs.append(task_run)
@@ -131,10 +140,33 @@ def test_all_dataset_filter(task_run):
 
 
 def test_high_rating_dataset_filter(sample_task_runs):
+    num_high_quality = 0
+    num_low_quality = 0
     for task_run in sample_task_runs:
-
-
+        if HighRatingDatasetFilter(task_run):
+            num_high_quality += 1
+            assert task_run.output.rating.is_high_quality() is True
+        else:
+            num_low_quality += 1
+            assert task_run.output.rating.is_high_quality() is False
+
+    # Test repaired output always considered high quality
+    task_run = task_run.model_copy(
+        update={
+            "repair_instructions": "repair instructions",
+            "repaired_output": TaskOutput(
+                output="repaired output",
+                source=DataSource(
+                    type=DataSourceType.human,
+                    properties={"created_by": "test-user"},
+                ),
+            ),
+        }
     )
+    assert HighRatingDatasetFilter(task_run) is True
+
+    assert num_high_quality == 6
+    assert num_low_quality == 4
 
 
 @pytest.mark.parametrize(
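
The rewritten test fixes the expected semantics: a run passes HighRatingDatasetFilter when its output rating reports is_high_quality(), and a run that carries a repaired_output always counts as high quality. A hedged sketch of logic consistent with those assertions (not the shipped implementation in dataset_filters.py):

```python
from kiln_ai.datamodel import TaskRun


def high_rating_filter_sketch(task_run: TaskRun) -> bool:
    # Repaired runs are always treated as high quality, per the test above.
    if task_run.repaired_output is not None:
        return True
    rating = task_run.output.rating
    return rating is not None and rating.is_high_quality()
```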
@@ -173,9 +205,11 @@ def test_dataset_split_with_high_rating_filter(sample_task, sample_task_runs):
         "Split Name",
         sample_task,
         Train80Test20SplitDefinition,
-
+        filter_id="high_rating",
     )
 
+    assert dataset.filter == "high_rating"
+
     # Check that only high-rated task runs are included
     all_ids = []
     for ids in dataset.split_contents.values():
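
Dataset splits are now created with a filter_id string (here "high_rating") rather than a filter callable, and the chosen ID is persisted on dataset.filter. A sketch of the call shape, assuming the DatasetSplit factory the test invokes; its name is not visible in this hunk, so from_task is an assumption:

```python
from kiln_ai.datamodel import DatasetSplit, Task
from kiln_ai.datamodel.dataset_split import Train80Test20SplitDefinition


def build_high_rated_split(task: Task) -> DatasetSplit:
    # Hypothetical factory name; only the argument shape is taken from the test above.
    dataset = DatasetSplit.from_task(
        "Split Name",
        task,
        Train80Test20SplitDefinition,
        filter_id="high_rating",
    )
    assert dataset.filter == "high_rating"
    return dataset
```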
@@ -232,3 +266,90 @@ def test_smaller_sample(sample_task, sample_task_runs):
 
     # Now we should have 0 missing runs. It's okay that dataset has newer data.
     assert dataset.missing_count() == 0
+
+
+@pytest.mark.parametrize(
+    "thinking_data,expected_result",
+    [
+        ({"reasoning": "Here's my answer"}, True),
+        ({"chain_of_thought": "Here's my answer"}, True),
+        ({"unknown": "Here's my answer"}, False),
+        ({}, False),
+        (None, False),
+    ],
+)
+def test_thinking_model_dataset_filter(
+    sample_task_runs, thinking_data, expected_result
+):
+    # Create a task run with thinking output
+    task_run = sample_task_runs[0].model_copy(
+        update={
+            "output": TaskOutput(
+                output="Let me think about this...\nHere's my answer",
+                source=DataSource(
+                    type=DataSourceType.human,
+                    properties={"created_by": "test-user"},
+                ),
+                rating=TaskOutputRating(value=5, type=TaskOutputRatingType.five_star),
+            ),
+            "intermediate_outputs": thinking_data,
+        }
+    )
+
+    assert ThinkingModelDatasetFilter(task_run) is expected_result
+
+
+@pytest.mark.parametrize(
+    "thinking_data,rating,expected_result",
+    [
+        ({"reasoning": "Here's my answer"}, 5, True),
+        ({"chain_of_thought": "Here's my answer"}, 5, True),
+        ({"unknown": "Here's my answer"}, 5, False),
+        ({}, 5, False),
+        (None, 5, False),
+        ({"reasoning": "Here's my answer"}, 1, False),
+        ({"chain_of_thought": "Here's my answer"}, 1, False),
+        ({"unknown": "Here's my answer"}, 1, False),
+        ({}, 1, False),
+        (None, 1, False),
+    ],
+)
+def test_thinking_model_dataset_filter_high_rated(
+    sample_task_runs, thinking_data, rating, expected_result
+):
+    # Create a task run with thinking output
+    task_run = sample_task_runs[0].model_copy(
+        update={
+            "output": TaskOutput(
+                output="Let me think about this...\nHere's my answer",
+                source=DataSource(
+                    type=DataSourceType.human,
+                    properties={"created_by": "test-user"},
+                ),
+                rating=TaskOutputRating(
+                    value=rating, type=TaskOutputRatingType.five_star
+                ),
+            ),
+            "intermediate_outputs": thinking_data,
+        }
+    )
+
+    assert ThinkingModelHighRatedFilter(task_run) is expected_result
+
+
+def test_tag_dataset_filter(sample_task_runs):
+    num_tagged = 0
+    num_untagged = 0
+    filter = TagFilter("tag1")
+    for task_run in sample_task_runs:
+        if "tag1" in task_run.tags:
+            num_tagged += 1
+            assert "tag1" in task_run.tags
+            assert filter(task_run) is True
+        else:
+            num_untagged += 1
+            assert "tag1" not in task_run.tags
+            assert filter(task_run) is False
+
+    assert num_tagged == 6
+    assert num_untagged == 4
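
These parametrized cases define the thinking-model filters by example: a run passes ThinkingModelDatasetFilter when its intermediate_outputs contain a "reasoning" or "chain_of_thought" entry, and ThinkingModelHighRatedFilter additionally requires a high rating. A sketch consistent with the cases above (again, not the shipped implementation):

```python
from kiln_ai.datamodel import TaskRun

THINKING_KEYS = ("reasoning", "chain_of_thought")


def thinking_model_filter_sketch(task_run: TaskRun) -> bool:
    outputs = task_run.intermediate_outputs or {}
    return any(key in outputs for key in THINKING_KEYS)


def thinking_model_high_rated_filter_sketch(task_run: TaskRun) -> bool:
    rating = task_run.output.rating
    return (
        thinking_model_filter_sketch(task_run)
        and rating is not None
        and rating.is_high_quality()
    )
```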
kiln_ai/datamodel/test_datasource.py

@@ -18,14 +18,14 @@ def test_valid_synthetic_data_source():
         properties={
             "model_name": "GPT-4",
             "model_provider": "OpenAI",
-            "
+            "prompt_id": "simple_prompt_builder",
             "adapter_name": "langchain",
         },
     )
     assert data_source.type == DataSourceType.synthetic
     assert data_source.properties["model_name"] == "GPT-4"
     assert data_source.properties["model_provider"] == "OpenAI"
-    assert data_source.properties["
+    assert data_source.properties["prompt_id"] == "simple_prompt_builder"
     assert data_source.properties["adapter_name"] == "langchain"
 
 
@@ -85,6 +85,7 @@ def test_prompt_type_optional_for_synthetic():
         },
     )
     assert "prompt_builder_name" not in data_source.properties
+    assert "prompt_id" not in data_source.properties
 
 
 def test_private_data_source_properties_not_serialized():
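
Finally, the DataSource tests track the move of the synthetic-source prompt property to a "prompt_id" key (matching the new kiln_ai/datamodel/prompt_id.py module); it appears to replace the older "prompt_builder_name" key, since both are asserted absent in the optional-prompt test. A short construction sketch mirroring the updated test; the property values are the test's own examples:

```python
from kiln_ai.datamodel import DataSource, DataSourceType

data_source = DataSource(
    type=DataSourceType.synthetic,
    properties={
        "model_name": "GPT-4",
        "model_provider": "OpenAI",
        "prompt_id": "simple_prompt_builder",  # new key in 0.12.x per the test above
        "adapter_name": "langchain",
    },
)
assert data_source.properties["prompt_id"] == "simple_prompt_builder"
```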
|