kiln-ai 0.8.1__py3-none-any.whl → 0.12.0__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries, and is provided for informational purposes only.

Potentially problematic release: this version of kiln-ai has been flagged as possibly problematic.

Files changed (88)
  1. kiln_ai/adapters/__init__.py +7 -7
  2. kiln_ai/adapters/adapter_registry.py +81 -10
  3. kiln_ai/adapters/data_gen/data_gen_task.py +21 -3
  4. kiln_ai/adapters/data_gen/test_data_gen_task.py +23 -3
  5. kiln_ai/adapters/eval/base_eval.py +164 -0
  6. kiln_ai/adapters/eval/eval_runner.py +267 -0
  7. kiln_ai/adapters/eval/g_eval.py +367 -0
  8. kiln_ai/adapters/eval/registry.py +16 -0
  9. kiln_ai/adapters/eval/test_base_eval.py +324 -0
  10. kiln_ai/adapters/eval/test_eval_runner.py +640 -0
  11. kiln_ai/adapters/eval/test_g_eval.py +497 -0
  12. kiln_ai/adapters/eval/test_g_eval_data.py +4 -0
  13. kiln_ai/adapters/fine_tune/base_finetune.py +5 -1
  14. kiln_ai/adapters/fine_tune/dataset_formatter.py +310 -65
  15. kiln_ai/adapters/fine_tune/fireworks_finetune.py +47 -32
  16. kiln_ai/adapters/fine_tune/openai_finetune.py +12 -11
  17. kiln_ai/adapters/fine_tune/test_base_finetune.py +19 -0
  18. kiln_ai/adapters/fine_tune/test_dataset_formatter.py +472 -129
  19. kiln_ai/adapters/fine_tune/test_fireworks_finetune.py +114 -22
  20. kiln_ai/adapters/fine_tune/test_openai_finetune.py +125 -14
  21. kiln_ai/adapters/ml_model_list.py +434 -93
  22. kiln_ai/adapters/model_adapters/__init__.py +18 -0
  23. kiln_ai/adapters/model_adapters/base_adapter.py +250 -0
  24. kiln_ai/adapters/model_adapters/langchain_adapters.py +309 -0
  25. kiln_ai/adapters/model_adapters/openai_compatible_config.py +10 -0
  26. kiln_ai/adapters/model_adapters/openai_model_adapter.py +289 -0
  27. kiln_ai/adapters/model_adapters/test_base_adapter.py +199 -0
  28. kiln_ai/adapters/{test_langchain_adapter.py → model_adapters/test_langchain_adapter.py} +105 -97
  29. kiln_ai/adapters/model_adapters/test_openai_model_adapter.py +216 -0
  30. kiln_ai/adapters/{test_saving_adapter_results.py → model_adapters/test_saving_adapter_results.py} +80 -30
  31. kiln_ai/adapters/{test_structured_output.py → model_adapters/test_structured_output.py} +125 -46
  32. kiln_ai/adapters/ollama_tools.py +0 -1
  33. kiln_ai/adapters/parsers/__init__.py +10 -0
  34. kiln_ai/adapters/parsers/base_parser.py +12 -0
  35. kiln_ai/adapters/parsers/json_parser.py +37 -0
  36. kiln_ai/adapters/parsers/parser_registry.py +19 -0
  37. kiln_ai/adapters/parsers/r1_parser.py +69 -0
  38. kiln_ai/adapters/parsers/test_json_parser.py +81 -0
  39. kiln_ai/adapters/parsers/test_parser_registry.py +32 -0
  40. kiln_ai/adapters/parsers/test_r1_parser.py +144 -0
  41. kiln_ai/adapters/prompt_builders.py +193 -49
  42. kiln_ai/adapters/provider_tools.py +91 -36
  43. kiln_ai/adapters/repair/repair_task.py +18 -19
  44. kiln_ai/adapters/repair/test_repair_task.py +7 -7
  45. kiln_ai/adapters/run_output.py +11 -0
  46. kiln_ai/adapters/test_adapter_registry.py +177 -0
  47. kiln_ai/adapters/test_generate_docs.py +69 -0
  48. kiln_ai/adapters/test_ollama_tools.py +0 -1
  49. kiln_ai/adapters/test_prompt_adaptors.py +25 -18
  50. kiln_ai/adapters/test_prompt_builders.py +265 -44
  51. kiln_ai/adapters/test_provider_tools.py +268 -46
  52. kiln_ai/datamodel/__init__.py +51 -772
  53. kiln_ai/datamodel/basemodel.py +31 -11
  54. kiln_ai/datamodel/datamodel_enums.py +58 -0
  55. kiln_ai/datamodel/dataset_filters.py +114 -0
  56. kiln_ai/datamodel/dataset_split.py +170 -0
  57. kiln_ai/datamodel/eval.py +298 -0
  58. kiln_ai/datamodel/finetune.py +105 -0
  59. kiln_ai/datamodel/json_schema.py +14 -3
  60. kiln_ai/datamodel/model_cache.py +8 -3
  61. kiln_ai/datamodel/project.py +23 -0
  62. kiln_ai/datamodel/prompt.py +37 -0
  63. kiln_ai/datamodel/prompt_id.py +83 -0
  64. kiln_ai/datamodel/strict_mode.py +24 -0
  65. kiln_ai/datamodel/task.py +181 -0
  66. kiln_ai/datamodel/task_output.py +321 -0
  67. kiln_ai/datamodel/task_run.py +164 -0
  68. kiln_ai/datamodel/test_basemodel.py +80 -2
  69. kiln_ai/datamodel/test_dataset_filters.py +71 -0
  70. kiln_ai/datamodel/test_dataset_split.py +127 -6
  71. kiln_ai/datamodel/test_datasource.py +3 -2
  72. kiln_ai/datamodel/test_eval_model.py +635 -0
  73. kiln_ai/datamodel/test_example_models.py +34 -17
  74. kiln_ai/datamodel/test_json_schema.py +23 -0
  75. kiln_ai/datamodel/test_model_cache.py +24 -0
  76. kiln_ai/datamodel/test_model_perf.py +125 -0
  77. kiln_ai/datamodel/test_models.py +131 -2
  78. kiln_ai/datamodel/test_prompt_id.py +129 -0
  79. kiln_ai/datamodel/test_task.py +159 -0
  80. kiln_ai/utils/config.py +6 -1
  81. kiln_ai/utils/exhaustive_error.py +6 -0
  82. {kiln_ai-0.8.1.dist-info → kiln_ai-0.12.0.dist-info}/METADATA +45 -7
  83. kiln_ai-0.12.0.dist-info/RECORD +100 -0
  84. kiln_ai/adapters/base_adapter.py +0 -191
  85. kiln_ai/adapters/langchain_adapters.py +0 -256
  86. kiln_ai-0.8.1.dist-info/RECORD +0 -58
  87. {kiln_ai-0.8.1.dist-info → kiln_ai-0.12.0.dist-info}/WHEEL +0 -0
  88. {kiln_ai-0.8.1.dist-info → kiln_ai-0.12.0.dist-info}/licenses/LICENSE.txt +0 -0
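The headline change in this release is structural: model adapters move from top-level modules into the kiln_ai.adapters.model_adapters package (files 22-30), new parsers and eval packages appear (files 5-12 and 33-40), and the old kiln_ai/adapters/base_adapter.py and langchain_adapters.py are deleted (files 84-85). Callers therefore need to update their import paths; a minimal sketch, based on the import lines visible in the hunks below (the 0.8.1 path is inferred from the deleted module):

    # 0.8.1 (module removed in 0.12.0):
    # from kiln_ai.adapters.base_adapter import BaseAdapter

    # 0.12.0:
    from kiln_ai.adapters.model_adapters.base_adapter import BaseAdapter
    from kiln_ai.adapters.run_output import RunOutput

Only four of the 88 changed files have their diffs reproduced below; all four are test modules under kiln_ai/datamodel.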
--- a/kiln_ai/datamodel/test_basemodel.py
+++ b/kiln_ai/datamodel/test_basemodel.py
@@ -6,12 +6,16 @@ from unittest.mock import MagicMock, patch
 
 import pytest
 
+from kiln_ai.adapters.model_adapters.base_adapter import BaseAdapter
+from kiln_ai.adapters.run_output import RunOutput
+from kiln_ai.datamodel import Task, TaskRun
 from kiln_ai.datamodel.basemodel import (
     KilnBaseModel,
     KilnParentedModel,
     string_to_valid_name,
 )
 from kiln_ai.datamodel.model_cache import ModelCache
+from kiln_ai.datamodel.task import RunConfig
 
 
 @pytest.fixture
@@ -356,7 +360,9 @@ def test_load_from_file_with_cache(test_base_file, tmp_model_cache):
     model = KilnBaseModel.load_from_file(test_base_file)
 
     # Check that the cache was checked and set
-    tmp_model_cache.get_model.assert_called_once_with(test_base_file, KilnBaseModel)
+    tmp_model_cache.get_model.assert_called_once_with(
+        test_base_file, KilnBaseModel, readonly=False
+    )
     tmp_model_cache.set_model.assert_called_once()
 
     # Ensure the model is correctly loaded
@@ -407,7 +413,9 @@ def test_load_from_file_with_cached_model(test_base_file, tmp_model_cache):
     model = KilnBaseModel.load_from_file(test_base_file)
 
     # Check that the cache was checked and the cached model was returned
-    tmp_model_cache.get_model.assert_called_once_with(test_base_file, KilnBaseModel)
+    tmp_model_cache.get_model.assert_called_once_with(
+        test_base_file, KilnBaseModel, readonly=False
+    )
     assert model is cached_model
 
     # Assert that open was not called (we used the cached model, not file)
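Both cache assertions change because ModelCache.get_model grew a readonly keyword, which KilnBaseModel.load_from_file now passes explicitly. The tests only pin the call shape; a minimal caller sketch under that assumption (the path and the default-constructed cache are illustrative):

    from pathlib import Path

    from kiln_ai.datamodel.basemodel import KilnBaseModel
    from kiln_ai.datamodel.model_cache import ModelCache

    cache = ModelCache()  # assumed default constructor; the tests mock this object
    path = Path("project/task.kiln")  # hypothetical model file

    # Mirrors the assertion above: load_from_file consults the cache with
    # readonly=False, then falls back to reading the file and set_model().
    model = cache.get_model(path, KilnBaseModel, readonly=False)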
@@ -469,3 +477,73 @@ def test_from_id_and_parent_path_without_parent():
     # Test with None parent_path
     not_found = DefaultParentedModel.from_id_and_parent_path("any-id", None)
     assert not_found is None
+
+
+class MockAdapter(BaseAdapter):
+    """Implementation of BaseAdapter for testing"""
+
+    async def _run(self, input):
+        return RunOutput(output="test output", intermediate_outputs=None)
+
+    def adapter_name(self) -> str:
+        return "test"
+
+
+@pytest.fixture
+def base_task():
+    return Task(name="test_task", instruction="test_instruction")
+
+
+@pytest.fixture
+def adapter(base_task):
+    return MockAdapter(
+        run_config=RunConfig(
+            task=base_task,
+            model_name="test_model",
+            model_provider_name="test_provider",
+            prompt_id="simple_prompt_builder",
+        ),
+    )
+
+
+async def test_invoke_parsing_flow(adapter):
+    # Mock dependencies
+    mock_provider = MagicMock()
+    mock_provider.parser = "test_parser"
+
+    mock_parser = MagicMock()
+    mock_parser.parse_output.return_value = RunOutput(
+        output="parsed test output", intermediate_outputs={"key": "value"}
+    )
+
+    mock_parser_class = MagicMock(return_value=mock_parser)
+
+    with (
+        patch.object(adapter, "model_provider", return_value=mock_provider),
+        patch(
+            "kiln_ai.adapters.model_adapters.base_adapter.model_parser_from_id",
+            return_value=mock_parser_class,
+        ),
+        patch("kiln_ai.adapters.model_adapters.base_adapter.Config") as mock_config,
+    ):
+        # Disable autosaving for this test
+        mock_config.shared.return_value.autosave_runs = False
+        mock_config.shared.return_value.user_id = "test_user_id"
+
+        # Execute
+        result = await adapter.invoke("test input")
+
+        # Verify parser was created correctly
+        mock_parser_class.assert_called_once_with(structured_output=False)
+
+        # Verify parsing occurred
+        mock_parser.parse_output.assert_called_once()
+        parsed_args = mock_parser.parse_output.call_args[1]
+        assert isinstance(parsed_args["original_output"], RunOutput)
+        assert parsed_args["original_output"].output == "test output"
+
+        # Verify result contains parsed output
+        assert isinstance(result, TaskRun)
+        assert result.output.output == "parsed test output"
+        assert result.intermediate_outputs == {"key": "value"}
+        assert result.input == "test input"
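Beyond exercising the parser flow, this new test doubles as a recipe for the 0.12.0 adapter API: subclass BaseAdapter, implement _run and adapter_name, and construct the adapter with a RunConfig rather than loose arguments. A sketch outside the test harness, with placeholder model and provider names:

    from kiln_ai.adapters.model_adapters.base_adapter import BaseAdapter
    from kiln_ai.adapters.run_output import RunOutput
    from kiln_ai.datamodel import Task
    from kiln_ai.datamodel.task import RunConfig


    class EchoAdapter(BaseAdapter):
        """Toy adapter mirroring MockAdapter above; echoes its input."""

        async def _run(self, input):
            return RunOutput(output=f"echo: {input}", intermediate_outputs=None)

        def adapter_name(self) -> str:
            return "echo"


    task = Task(name="demo", instruction="Echo the input back.")
    adapter = EchoAdapter(
        run_config=RunConfig(
            task=task,
            model_name="placeholder_model",  # hypothetical
            model_provider_name="placeholder_provider",  # hypothetical
            prompt_id="simple_prompt_builder",
        ),
    )
    # result = await adapter.invoke("hello")  # returns a TaskRun, per the test above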
--- /dev/null
+++ b/kiln_ai/datamodel/test_dataset_filters.py
@@ -0,0 +1,71 @@
+import pytest
+from pydantic import BaseModel
+
+from kiln_ai.datamodel.dataset_filters import (
+    AllDatasetFilter,
+    DatasetFilterId,
+    HighRatingDatasetFilter,
+    StaticDatasetFilters,
+    TagFilter,
+    ThinkingModelDatasetFilter,
+    ThinkingModelHighRatedFilter,
+    dataset_filter_from_id,
+)
+
+# Note: Many more filter tests in test_dataset_split.py
+
+
+def test_all_dataset_filter_from_id():
+    assert dataset_filter_from_id("all") == AllDatasetFilter
+
+
+def test_high_rating_dataset_filter_from_id():
+    assert dataset_filter_from_id("high_rating") == HighRatingDatasetFilter
+
+
+def test_thinking_model_dataset_filter_from_id():
+    assert dataset_filter_from_id("thinking_model") == ThinkingModelDatasetFilter
+
+
+def test_thinking_model_high_rated_dataset_filter_from_id():
+    assert (
+        dataset_filter_from_id("thinking_model_high_rated")
+        == ThinkingModelHighRatedFilter
+    )
+
+
+def test_all_static_dataset_filters():
+    for filter_id in StaticDatasetFilters:
+        assert dataset_filter_from_id(filter_id) is not None
+
+
+class ModelTester(BaseModel):
+    dsid: DatasetFilterId
+
+
+@pytest.mark.parametrize(
+    "tag,expected_error,expected_tag",
+    [
+        ("tag::test", False, "test"),
+        ("tag::other", False, "other"),
+        ("tag::", True, None),
+        ("tag", True, None),
+        ("", True, None),
+    ],
+)
+def test_tag_filter(tag, expected_error, expected_tag):
+    # Check our model validators
+    if expected_error:
+        with pytest.raises(ValueError):
+            ModelTester(dsid=tag)
+    else:
+        ModelTester(dsid=tag)
+
+    # Check the constructor
+    if expected_tag is None:
+        with pytest.raises(ValueError, match="Invalid dataset filter ID:"):
+            dataset_filter_from_id(tag)
+    else:
+        filter = dataset_filter_from_id(tag)
+        assert isinstance(filter, TagFilter)
+        assert filter.tag == expected_tag
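Taken together, these tests document the DatasetFilterId grammar introduced in 0.12.0: four static IDs ("all", "high_rating", "thinking_model", "thinking_model_high_rated") plus a parameterized "tag::<name>" form that resolves to a TagFilter instance. A short usage sketch (the tag name is made up):

    from kiln_ai.datamodel.dataset_filters import TagFilter, dataset_filter_from_id

    tag_filter = dataset_filter_from_id("tag::golden")  # "golden" is illustrative
    assert isinstance(tag_filter, TagFilter)
    assert tag_filter.tag == "golden"

    # Static IDs resolve to filter callables applied per TaskRun:
    high_rating = dataset_filter_from_id("high_rating")
    # kept = [run for run in task_runs if high_rating(run)]

Because DatasetFilterId is a validated pydantic type, malformed IDs such as "tag::" fail at model validation time as well as in dataset_filter_from_id.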
--- a/kiln_ai/datamodel/test_dataset_split.py
+++ b/kiln_ai/datamodel/test_dataset_split.py
@@ -3,21 +3,28 @@ from pydantic import ValidationError
 
 # import datamodel first or we get circular import errors
 from kiln_ai.datamodel import (
-    AllDatasetFilter,
-    AllSplitDefinition,
     DatasetSplit,
     DatasetSplitDefinition,
     DataSource,
     DataSourceType,
-    HighRatingDatasetFilter,
     Task,
     TaskOutput,
     TaskOutputRating,
     TaskOutputRatingType,
     TaskRun,
+)
+from kiln_ai.datamodel.dataset_split import (
+    AllSplitDefinition,
     Train60Test20Val20SplitDefinition,
     Train80Test20SplitDefinition,
 )
+from kiln_ai.datamodel.test_dataset_filters import (
+    AllDatasetFilter,
+    HighRatingDatasetFilter,
+    TagFilter,
+    ThinkingModelDatasetFilter,
+    ThinkingModelHighRatedFilter,
+)
 
 
 @pytest.fixture
@@ -39,6 +46,7 @@ def sample_task_runs(sample_task):
     task_runs = []
     for i in range(10):
         rating = 5 if i < 6 else 1  # 6 high, 4 low ratings
+        tags = ["tag1"] if i < 6 else []
         task_run = TaskRun(
             parent=sample_task,
             input=f"input_{i}",
@@ -56,6 +64,7 @@
                     value=rating, type=TaskOutputRatingType.five_star
                 ),
             ),
+            tags=tags,
         )
         task_run.save_to_file()
         task_runs.append(task_run)
@@ -131,10 +140,33 @@ def test_all_dataset_filter(task_run):
 
 
 def test_high_rating_dataset_filter(sample_task_runs):
+    num_high_quality = 0
+    num_low_quality = 0
     for task_run in sample_task_runs:
-        assert HighRatingDatasetFilter(task_run) is (
-            task_run.output.rating.is_high_quality()
+        if HighRatingDatasetFilter(task_run):
+            num_high_quality += 1
+            assert task_run.output.rating.is_high_quality() is True
+        else:
+            num_low_quality += 1
+            assert task_run.output.rating.is_high_quality() is False
+
+    # Test repaired output always considered high quality
+    task_run = task_run.model_copy(
+        update={
+            "repair_instructions": "repair instructions",
+            "repaired_output": TaskOutput(
+                output="repaired output",
+                source=DataSource(
+                    type=DataSourceType.human,
+                    properties={"created_by": "test-user"},
+                ),
+            ),
+        }
     )
+    assert HighRatingDatasetFilter(task_run) is True
+
+    assert num_high_quality == 6
+    assert num_low_quality == 4
 
 
 @pytest.mark.parametrize(
@@ -173,9 +205,11 @@ def test_dataset_split_with_high_rating_filter(sample_task, sample_task_runs):
         "Split Name",
         sample_task,
         Train80Test20SplitDefinition,
-        filter=HighRatingDatasetFilter,
+        filter_id="high_rating",
     )
 
+    assert dataset.filter == "high_rating"
+
     # Check that only high-rated task runs are included
     all_ids = []
     for ids in dataset.split_contents.values():
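The split API change here is from passing a filter callable (filter=HighRatingDatasetFilter) to passing a serializable ID (filter_id="high_rating") that is stored on dataset.filter. A sketch of the new call, assuming DatasetSplit.from_task is the constructor this test exercises (its name is not visible in the hunk):

    from kiln_ai.datamodel import DatasetSplit
    from kiln_ai.datamodel.dataset_split import Train80Test20SplitDefinition

    # `task` is a saved kiln_ai.datamodel.Task with rated runs, as in the fixture.
    dataset = DatasetSplit.from_task(  # constructor name assumed
        "Split Name",
        task,
        Train80Test20SplitDefinition,
        filter_id="high_rating",  # 0.8.1 took filter=HighRatingDatasetFilter
    )
    assert dataset.filter == "high_rating"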
@@ -232,3 +266,90 @@ def test_smaller_sample(sample_task, sample_task_runs):
 
     # Now we should have 0 missing runs. It's okay that dataset has newer data.
     assert dataset.missing_count() == 0
+
+
+@pytest.mark.parametrize(
+    "thinking_data,expected_result",
+    [
+        ({"reasoning": "Here's my answer"}, True),
+        ({"chain_of_thought": "Here's my answer"}, True),
+        ({"unknown": "Here's my answer"}, False),
+        ({}, False),
+        (None, False),
+    ],
+)
+def test_thinking_model_dataset_filter(
+    sample_task_runs, thinking_data, expected_result
+):
+    # Create a task run with thinking output
+    task_run = sample_task_runs[0].model_copy(
+        update={
+            "output": TaskOutput(
+                output="Let me think about this...\nHere's my answer",
+                source=DataSource(
+                    type=DataSourceType.human,
+                    properties={"created_by": "test-user"},
+                ),
+                rating=TaskOutputRating(value=5, type=TaskOutputRatingType.five_star),
+            ),
+            "intermediate_outputs": thinking_data,
+        }
+    )
+
+    assert ThinkingModelDatasetFilter(task_run) is expected_result
+
+
+@pytest.mark.parametrize(
+    "thinking_data,rating,expected_result",
+    [
+        ({"reasoning": "Here's my answer"}, 5, True),
+        ({"chain_of_thought": "Here's my answer"}, 5, True),
+        ({"unknown": "Here's my answer"}, 5, False),
+        ({}, 5, False),
+        (None, 5, False),
+        ({"reasoning": "Here's my answer"}, 1, False),
+        ({"chain_of_thought": "Here's my answer"}, 1, False),
+        ({"unknown": "Here's my answer"}, 1, False),
+        ({}, 1, False),
+        (None, 1, False),
+    ],
+)
+def test_thinking_model_dataset_filter_high_rated(
+    sample_task_runs, thinking_data, rating, expected_result
+):
+    # Create a task run with thinking output
+    task_run = sample_task_runs[0].model_copy(
+        update={
+            "output": TaskOutput(
+                output="Let me think about this...\nHere's my answer",
+                source=DataSource(
+                    type=DataSourceType.human,
+                    properties={"created_by": "test-user"},
+                ),
+                rating=TaskOutputRating(
+                    value=rating, type=TaskOutputRatingType.five_star
+                ),
+            ),
+            "intermediate_outputs": thinking_data,
+        }
+    )
+
+    assert ThinkingModelHighRatedFilter(task_run) is expected_result
+
+
+def test_tag_dataset_filter(sample_task_runs):
+    num_tagged = 0
+    num_untagged = 0
+    filter = TagFilter("tag1")
+    for task_run in sample_task_runs:
+        if "tag1" in task_run.tags:
+            num_tagged += 1
+            assert "tag1" in task_run.tags
+            assert filter(task_run) is True
+        else:
+            num_untagged += 1
+            assert "tag1" not in task_run.tags
+            assert filter(task_run) is False
+
+    assert num_tagged == 6
+    assert num_untagged == 4
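Both thinking filters hinge on intermediate_outputs: a run qualifies only when it carries a "reasoning" or "chain_of_thought" entry, and the high-rated variant additionally requires a high output rating. A hedged restatement of just the key check the parametrized cases pin down (not the library's actual implementation):

    THINKING_KEYS = ("reasoning", "chain_of_thought")  # illustrative constant

    def looks_like_thinking(intermediate_outputs: dict | None) -> bool:
        # None and {} are both rejected, matching the parametrized cases above.
        if not intermediate_outputs:
            return False
        return any(key in intermediate_outputs for key in THINKING_KEYS)

    assert looks_like_thinking({"reasoning": "..."}) is True
    assert looks_like_thinking({"chain_of_thought": "..."}) is True
    assert looks_like_thinking({"unknown": "..."}) is False
    assert looks_like_thinking(None) is False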
--- a/kiln_ai/datamodel/test_datasource.py
+++ b/kiln_ai/datamodel/test_datasource.py
@@ -18,14 +18,14 @@ def test_valid_synthetic_data_source():
         properties={
             "model_name": "GPT-4",
             "model_provider": "OpenAI",
-            "prompt_builder_name": "completion",
+            "prompt_id": "simple_prompt_builder",
             "adapter_name": "langchain",
         },
     )
     assert data_source.type == DataSourceType.synthetic
     assert data_source.properties["model_name"] == "GPT-4"
     assert data_source.properties["model_provider"] == "OpenAI"
-    assert data_source.properties["prompt_builder_name"] == "completion"
+    assert data_source.properties["prompt_id"] == "simple_prompt_builder"
     assert data_source.properties["adapter_name"] == "langchain"
 
 
@@ -85,6 +85,7 @@ def test_prompt_type_optional_for_synthetic():
         },
     )
     assert "prompt_builder_name" not in data_source.properties
+    assert "prompt_id" not in data_source.properties
 
 
 def test_private_data_source_properties_not_serialized():
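The rename from prompt_builder_name to prompt_id also reaches the synthetic DataSource properties: the old value was a builder name like "completion", the new one a prompt ID like "simple_prompt_builder", and the key remains optional. A 0.12.0-style synthetic source, copied from the shapes in this test:

    from kiln_ai.datamodel import DataSource, DataSourceType

    source = DataSource(
        type=DataSourceType.synthetic,
        properties={
            "model_name": "GPT-4",
            "model_provider": "OpenAI",
            "prompt_id": "simple_prompt_builder",  # was "prompt_builder_name"
            "adapter_name": "langchain",
        },
    )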