kiln-ai 0.8.1__py3-none-any.whl → 0.12.0__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry, and is provided for informational purposes only.

Files changed (88)
  1. kiln_ai/adapters/__init__.py +7 -7
  2. kiln_ai/adapters/adapter_registry.py +81 -10
  3. kiln_ai/adapters/data_gen/data_gen_task.py +21 -3
  4. kiln_ai/adapters/data_gen/test_data_gen_task.py +23 -3
  5. kiln_ai/adapters/eval/base_eval.py +164 -0
  6. kiln_ai/adapters/eval/eval_runner.py +267 -0
  7. kiln_ai/adapters/eval/g_eval.py +367 -0
  8. kiln_ai/adapters/eval/registry.py +16 -0
  9. kiln_ai/adapters/eval/test_base_eval.py +324 -0
  10. kiln_ai/adapters/eval/test_eval_runner.py +640 -0
  11. kiln_ai/adapters/eval/test_g_eval.py +497 -0
  12. kiln_ai/adapters/eval/test_g_eval_data.py +4 -0
  13. kiln_ai/adapters/fine_tune/base_finetune.py +5 -1
  14. kiln_ai/adapters/fine_tune/dataset_formatter.py +310 -65
  15. kiln_ai/adapters/fine_tune/fireworks_finetune.py +47 -32
  16. kiln_ai/adapters/fine_tune/openai_finetune.py +12 -11
  17. kiln_ai/adapters/fine_tune/test_base_finetune.py +19 -0
  18. kiln_ai/adapters/fine_tune/test_dataset_formatter.py +472 -129
  19. kiln_ai/adapters/fine_tune/test_fireworks_finetune.py +114 -22
  20. kiln_ai/adapters/fine_tune/test_openai_finetune.py +125 -14
  21. kiln_ai/adapters/ml_model_list.py +434 -93
  22. kiln_ai/adapters/model_adapters/__init__.py +18 -0
  23. kiln_ai/adapters/model_adapters/base_adapter.py +250 -0
  24. kiln_ai/adapters/model_adapters/langchain_adapters.py +309 -0
  25. kiln_ai/adapters/model_adapters/openai_compatible_config.py +10 -0
  26. kiln_ai/adapters/model_adapters/openai_model_adapter.py +289 -0
  27. kiln_ai/adapters/model_adapters/test_base_adapter.py +199 -0
  28. kiln_ai/adapters/{test_langchain_adapter.py → model_adapters/test_langchain_adapter.py} +105 -97
  29. kiln_ai/adapters/model_adapters/test_openai_model_adapter.py +216 -0
  30. kiln_ai/adapters/{test_saving_adapter_results.py → model_adapters/test_saving_adapter_results.py} +80 -30
  31. kiln_ai/adapters/{test_structured_output.py → model_adapters/test_structured_output.py} +125 -46
  32. kiln_ai/adapters/ollama_tools.py +0 -1
  33. kiln_ai/adapters/parsers/__init__.py +10 -0
  34. kiln_ai/adapters/parsers/base_parser.py +12 -0
  35. kiln_ai/adapters/parsers/json_parser.py +37 -0
  36. kiln_ai/adapters/parsers/parser_registry.py +19 -0
  37. kiln_ai/adapters/parsers/r1_parser.py +69 -0
  38. kiln_ai/adapters/parsers/test_json_parser.py +81 -0
  39. kiln_ai/adapters/parsers/test_parser_registry.py +32 -0
  40. kiln_ai/adapters/parsers/test_r1_parser.py +144 -0
  41. kiln_ai/adapters/prompt_builders.py +193 -49
  42. kiln_ai/adapters/provider_tools.py +91 -36
  43. kiln_ai/adapters/repair/repair_task.py +18 -19
  44. kiln_ai/adapters/repair/test_repair_task.py +7 -7
  45. kiln_ai/adapters/run_output.py +11 -0
  46. kiln_ai/adapters/test_adapter_registry.py +177 -0
  47. kiln_ai/adapters/test_generate_docs.py +69 -0
  48. kiln_ai/adapters/test_ollama_tools.py +0 -1
  49. kiln_ai/adapters/test_prompt_adaptors.py +25 -18
  50. kiln_ai/adapters/test_prompt_builders.py +265 -44
  51. kiln_ai/adapters/test_provider_tools.py +268 -46
  52. kiln_ai/datamodel/__init__.py +51 -772
  53. kiln_ai/datamodel/basemodel.py +31 -11
  54. kiln_ai/datamodel/datamodel_enums.py +58 -0
  55. kiln_ai/datamodel/dataset_filters.py +114 -0
  56. kiln_ai/datamodel/dataset_split.py +170 -0
  57. kiln_ai/datamodel/eval.py +298 -0
  58. kiln_ai/datamodel/finetune.py +105 -0
  59. kiln_ai/datamodel/json_schema.py +14 -3
  60. kiln_ai/datamodel/model_cache.py +8 -3
  61. kiln_ai/datamodel/project.py +23 -0
  62. kiln_ai/datamodel/prompt.py +37 -0
  63. kiln_ai/datamodel/prompt_id.py +83 -0
  64. kiln_ai/datamodel/strict_mode.py +24 -0
  65. kiln_ai/datamodel/task.py +181 -0
  66. kiln_ai/datamodel/task_output.py +321 -0
  67. kiln_ai/datamodel/task_run.py +164 -0
  68. kiln_ai/datamodel/test_basemodel.py +80 -2
  69. kiln_ai/datamodel/test_dataset_filters.py +71 -0
  70. kiln_ai/datamodel/test_dataset_split.py +127 -6
  71. kiln_ai/datamodel/test_datasource.py +3 -2
  72. kiln_ai/datamodel/test_eval_model.py +635 -0
  73. kiln_ai/datamodel/test_example_models.py +34 -17
  74. kiln_ai/datamodel/test_json_schema.py +23 -0
  75. kiln_ai/datamodel/test_model_cache.py +24 -0
  76. kiln_ai/datamodel/test_model_perf.py +125 -0
  77. kiln_ai/datamodel/test_models.py +131 -2
  78. kiln_ai/datamodel/test_prompt_id.py +129 -0
  79. kiln_ai/datamodel/test_task.py +159 -0
  80. kiln_ai/utils/config.py +6 -1
  81. kiln_ai/utils/exhaustive_error.py +6 -0
  82. {kiln_ai-0.8.1.dist-info → kiln_ai-0.12.0.dist-info}/METADATA +45 -7
  83. kiln_ai-0.12.0.dist-info/RECORD +100 -0
  84. kiln_ai/adapters/base_adapter.py +0 -191
  85. kiln_ai/adapters/langchain_adapters.py +0 -256
  86. kiln_ai-0.8.1.dist-info/RECORD +0 -58
  87. {kiln_ai-0.8.1.dist-info → kiln_ai-0.12.0.dist-info}/WHEEL +0 -0
  88. {kiln_ai-0.8.1.dist-info → kiln_ai-0.12.0.dist-info}/licenses/LICENSE.txt +0 -0
@@ -11,7 +11,6 @@ from kiln_ai.datamodel import (
      Finetune,
      Project,
      Task,
-     TaskDeterminism,
      TaskOutput,
      TaskOutputRating,
      TaskOutputRatingType,
@@ -125,7 +124,6 @@ def test_structured_output_workflow(tmp_path):
          name="Structured Output Task",
          parent=project,
          instruction="Generate a JSON object with name and age",
-         determinism=TaskDeterminism.semantic_match,
          output_json_schema=json.dumps(
              {
                  "type": "object",
@@ -142,7 +140,7 @@ def test_structured_output_workflow(tmp_path):

      # Create runs
      runs = []
-     for source in DataSourceType:
+     for source in [DataSourceType.human, DataSourceType.synthetic]:
          for _ in range(2):
              task_run = TaskRun(
                  input="Generate info for John Doe",
@@ -157,7 +155,7 @@ def test_structured_output_workflow(tmp_path):
                          "adapter_name": "TestAdapter",
                          "model_name": "GPT-4",
                          "model_provider": "OpenAI",
-                         "prompt_builder_name": "TestPromptBuilder",
+                         "prompt_id": "simple_prompt_builder",
                      },
                  ),
                  parent=task,
@@ -216,9 +214,9 @@ def test_structured_output_workflow(tmp_path):

      assert loaded_task.name == "Structured Output Task"
      assert len(loaded_task.requirements) == 2
-     assert len(loaded_task.runs()) == 5
-
      loaded_runs = loaded_task.runs()
+     assert len(loaded_runs) == 5
+
      for task_run in loaded_runs:
          output = task_run.output
          assert output.rating is not None
@@ -284,6 +282,9 @@ def test_task_output_requirement_rating_keys(tmp_path):
      assert task_run.output.rating.requirement_ratings is not None


+ _schema_match = "This task requires a specific output schema. While the model produced JSON, that JSON didn't meet the schema."
+
+
  def test_task_output_schema_validation(tmp_path):
      # Create a project, task, and example hierarchy
      project = Project(name="Test Project", path=(tmp_path / "test_project"))
@@ -321,12 +322,24 @@ def test_task_output_schema_validation(tmp_path):
      task_output.save_to_file()

      # changing to invalid output
-     with pytest.raises(ValueError, match="does not match task output schema"):
+     with pytest.raises(
+         ValueError,
+         match=_schema_match,
+     ):
          task_output.output.output = '{"name": "John Doe", "age": "thirty"}'
          task_output.save_to_file()

+     # changing to invalid output from loaded model
+     loaded_task_output = TaskRun.load_from_file(task_output.path)
+     with pytest.raises(
+         ValueError,
+         match=_schema_match,
+     ):
+         loaded_task_output.output.output = '{"name": "John Doe", "age": "forty"}'
+         loaded_task_output.save_to_file()
+
      # Invalid case: output does not match task output schema
-     with pytest.raises(ValueError, match="does not match task output schema"):
+     with pytest.raises(ValueError, match=_schema_match):
          task_output = TaskRun(
              input="Test input",
              input_source=DataSource(
@@ -382,12 +395,18 @@ def test_task_input_schema_validation(tmp_path):
      valid_task_output.save_to_file()

      # Changing to invalid input
-     with pytest.raises(ValueError, match="does not match task input schema"):
+     with pytest.raises(ValueError, match=_schema_match):
          valid_task_output.input = '{"name": "John Doe", "age": "thirty"}'
          valid_task_output.save_to_file()

+     # loading from file, then changing to invalid input
+     loaded_task_output = TaskRun.load_from_file(valid_task_output.path)
+     with pytest.raises(ValueError, match=_schema_match):
+         loaded_task_output.input = '{"name": "John Doe", "age": "thirty"}'
+         loaded_task_output.save_to_file()
+
      # Invalid case: input does not match task input schema
-     with pytest.raises(ValueError, match="does not match task input schema"):
+     with pytest.raises(ValueError, match=_schema_match):
          task_output = TaskRun(
              input='{"name": "John Doe", "age": "thirty"}',
              input_source=DataSource(
@@ -451,7 +470,7 @@ def test_valid_synthetic_task_output():
                  "adapter_name": "TestAdapter",
                  "model_name": "GPT-4",
                  "model_provider": "OpenAI",
-                 "prompt_builder_name": "TestPromptBuilder",
+                 "prompt_id": "simple_prompt_builder",
              },
          ),
      )
@@ -459,7 +478,7 @@ def test_valid_synthetic_task_output():
      assert output.source.properties["adapter_name"] == "TestAdapter"
      assert output.source.properties["model_name"] == "GPT-4"
      assert output.source.properties["model_provider"] == "OpenAI"
-     assert output.source.properties["prompt_builder_name"] == "TestPromptBuilder"
+     assert output.source.properties["prompt_id"] == "simple_prompt_builder"


  def test_invalid_synthetic_task_output_missing_keys():
@@ -488,23 +507,21 @@ def test_invalid_synthetic_task_output_empty_values():
                      "adapter_name": "TestAdapter",
                      "model_name": "",
                      "model_provider": "OpenAI",
-                     "prompt_builder_name": "TestPromptBuilder",
+                     "prompt_id": "simple_prompt_builder",
                  },
              ),
          )


  def test_invalid_synthetic_task_output_non_string_values():
-     with pytest.raises(
-         ValidationError, match="'prompt_builder_name' must be of type str"
-     ):
+     with pytest.raises(ValidationError, match="'prompt_id' must be of type str"):
          DataSource(
              type=DataSourceType.synthetic,
              properties={
                  "adapter_name": "TestAdapter",
                  "model_name": "GPT-4",
                  "model_provider": "OpenAI",
-                 "prompt_builder_name": 123,
+                 "prompt_id": 123,
              },
          )

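
The hunks above all track one datamodel migration: the free-form prompt_builder_name property on synthetic DataSource records is replaced by a validated prompt_id string (see kiln_ai/datamodel/prompt_id.py in the file list). A minimal sketch of the new shape, using only values that appear in the tests above:

    from kiln_ai.datamodel import DataSource, DataSourceType

    # 0.12.0: the prompt is referenced by a prompt ID (here a generator name)
    # rather than by a prompt builder class name.
    source = DataSource(
        type=DataSourceType.synthetic,
        properties={
            "adapter_name": "TestAdapter",
            "model_name": "GPT-4",
            "model_provider": "OpenAI",
            "prompt_id": "simple_prompt_builder",  # was "prompt_builder_name": "TestPromptBuilder"
        },
    )
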
@@ -4,6 +4,7 @@ from pydantic import BaseModel
  from kiln_ai.datamodel.json_schema import (
      JsonObjectSchema,
      schema_from_json_str,
+     string_to_json_key,
      validate_schema,
  )

@@ -123,3 +124,25 @@ def test_triangle_schema():
      validate_schema({"a": 1, "b": 2, "c": 3}, json_triangle_schema)
      with pytest.raises(Exception):
          validate_schema({"a": 1, "b": 2, "c": "3"}, json_triangle_schema)
+
+
+ @pytest.mark.parametrize(
+     "input_str,expected",
+     [
+         ("hello world", "hello_world"),
+         ("Hello World", "hello_world"),
+         ("hello_world", "hello_world"),
+         ("HELLO WORLD", "hello_world"),
+         ("hello123", "hello123"),
+         ("hello-world", "helloworld"),
+         ("hello!@#$%^&*()world", "helloworld"),
+         (" hello world ", "hello__world"),
+         ("hello__world", "hello__world"),
+         ("", ""),
+         ("!@#$%", ""),
+         ("snake_case_string", "snake_case_string"),
+         ("camelCaseString", "camelcasestring"),
+     ],
+ )
+ def test_string_to_json_key(input_str: str, expected: str):
+     assert string_to_json_key(input_str) == expected
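
The parametrized cases above pin down the behavior of the new string_to_json_key helper: lowercase the input, map spaces to underscores, and drop anything that is not a lowercase letter, digit, or underscore (the "hello__world" expectation suggests leading/trailing whitespace is stripped and that the displayed input lost a repeated interior space during extraction). A minimal sketch consistent with these cases, not necessarily the library's exact implementation:

    import re


    def string_to_json_key(name: str) -> str:
        # Lowercase, turn spaces into underscores, then strip every character
        # that is not a lowercase letter, digit, or underscore.
        return re.sub(r"[^a-z0-9_]", "", name.strip().lower().replace(" ", "_"))
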
@@ -242,3 +242,27 @@ def test_check_timestamp_granularity_linux_error():
      cache = ModelCache()
      assert cache._check_timestamp_granularity() is False
      assert cache._enabled is False
+
+
+ def test_get_model_readonly(model_cache, test_path):
+     if not model_cache._enabled:
+         pytest.skip("Cache is disabled on this fs")
+
+     model = ModelTest(name="test", value=123)
+     mtime_ns = test_path.stat().st_mtime_ns
+
+     # Set the model in the cache
+     model_cache.set_model(test_path, model, mtime_ns)
+
+     # Get the model in readonly mode
+     readonly_model = model_cache.get_model(test_path, ModelTest, readonly=True)
+     # Get a regular (copied) model
+     copied_model = model_cache.get_model(test_path, ModelTest)
+
+     # The readonly model should be the exact same instance as the cached model
+     assert readonly_model is model_cache.model_cache[test_path][0]
+     # While the regular get should be a different instance
+     assert copied_model is not model_cache.model_cache[test_path][0]
+
+     # Both should have the same data
+     assert readonly_model == copied_model == model
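
The new test documents the readonly flag on ModelCache.get_model: a readonly read hands back the cached instance itself (callers must not mutate it), while a regular read returns an independent copy. A simplified, hypothetical cache illustrating that contract; the real ModelCache also tracks mtimes and filesystem timestamp granularity:

    from pydantic import BaseModel


    class TinyModelCache:
        """Toy cache mirroring the readonly contract exercised above."""

        def __init__(self) -> None:
            self.model_cache: dict = {}  # path -> (model, mtime_ns)

        def set_model(self, path, model: BaseModel, mtime_ns: int) -> None:
            self.model_cache[path] = (model, mtime_ns)

        def get_model(self, path, model_type: type[BaseModel], readonly: bool = False):
            cached, _ = self.model_cache[path]
            # readonly callers share the cached instance; everyone else gets a copy
            # that is safe to mutate without corrupting the cache.
            return cached if readonly else cached.model_copy(deep=True)
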
@@ -0,0 +1,125 @@
+ import shutil
+ import uuid
+
+ import pytest
+
+ from kiln_ai.datamodel import (
+     DataSource,
+     DataSourceType,
+     Project,
+     Task,
+     TaskOutput,
+     TaskRun,
+ )
+
+ test_json_schema = """{
+   "type": "object",
+   "properties": {
+     "setup": {
+       "description": "The setup of the joke",
+       "title": "Setup",
+       "type": "string"
+     },
+     "punchline": {
+       "description": "The punchline to the joke",
+       "title": "Punchline",
+       "type": "string"
+     },
+     "rating": {
+       "anyOf": [
+         {
+           "type": "integer"
+         },
+         {
+           "type": "null"
+         }
+       ],
+       "default": null,
+       "description": "How funny the joke is, from 1 to 10",
+       "title": "Rating"
+     }
+   },
+   "required": [
+     "setup",
+     "punchline"
+   ]
+ }
+ """
+
+
+ @pytest.fixture
+ def task_run(tmp_path):
+     # setup a valid project/task/task_run for testing
+     output_source = DataSource(
+         type=DataSourceType.synthetic,
+         properties={
+             "model_name": "test-model",
+             "model_provider": "test-provider",
+             "adapter_name": "test-adapter",
+         },
+     )
+
+     project_path = tmp_path / "project.kiln"
+     project = Project(name="Test Project", path=project_path)
+     project.save_to_file()
+     task = Task(
+         name="Test Task",
+         instruction="Test Instruction",
+         parent=project,
+         output_json_schema=test_json_schema,
+         input_json_schema=test_json_schema,
+     )
+
+     task.save_to_file()
+
+     task_output = TaskOutput(
+         output='{"setup": "Why did the chicken cross the road?", "punchline": "To get to the other side"}',
+         source=DataSource(
+             type=DataSourceType.synthetic,
+             properties={
+                 "model_name": "test-model",
+                 "model_provider": "test-provider",
+                 "adapter_name": "test-adapter",
+             },
+         ),
+     )
+
+     # Save for later usage
+     task_run = TaskRun(
+         input='{"setup": "Why did the chicken cross the road?", "punchline": "To get to the other side"}',
+         input_source=output_source,
+         output=task_output,
+     )
+     task_run.parent = task
+     task_run.save_to_file()
+
+     return task_run
+
+
+ @pytest.mark.benchmark
+ def test_benchmark_load_from_file(benchmark, task_run):
+     task_run_path = task_run.path
+
+     iterations = 500
+     total_time = 0
+
+     for _ in range(iterations):
+         # Copy the task to a new temp path, so we don't get warm loads/cached loads
+         temp_path = task_run.path.parent / f"temp_task_run_{uuid.uuid4()}.json"
+         shutil.copy(str(task_run_path), str(temp_path))
+
+         # only time loading the model (and one accessor for delayed validation)
+         start_time = benchmark._timer()
+         loaded = TaskRun.load_from_file(temp_path)
+         assert loaded.id == task_run.id
+         end_time = benchmark._timer()
+
+         total_time += end_time - start_time
+
+     avg_time_per_iteration = total_time / iterations
+     ops_per_second = 1.0 / avg_time_per_iteration
+
+     # I get 8k ops per second on my MBP. Lower value here for CI.
+     # Prior to optimization was 290 ops per second.
+     if ops_per_second < 1000:
+         pytest.fail(f"Ops per second: {ops_per_second:.6f}, expected more than 1k ops")
@@ -9,7 +9,9 @@ from kiln_ai.datamodel import (
      DataSource,
      DataSourceType,
      Finetune,
+     FinetuneDataStrategy,
      Project,
+     Prompt,
      Task,
      TaskOutput,
      TaskRun,
@@ -70,6 +72,20 @@ def test_save_to_file(test_project_file):
      assert data["description"] == "Test Description"


+ def test_save_to_file_non_ascii(test_project_file):
+     project = Project(
+         name="Test Project", description="Chúc mừng!", path=test_project_file
+     )
+     project.save_to_file()
+
+     with open(test_project_file, "r", encoding="utf-8") as file:
+         data = json.load(file)
+
+     assert data["v"] == 1
+     assert data["name"] == "Test Project"
+     assert data["description"] == "Chúc mừng!"
+
+
  def test_task_defaults():
      task = Task(name="Test Task", instruction="Test Instruction")
      assert task.description is None
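
test_save_to_file_non_ascii asserts that a description like "Chúc mừng!" survives a save/load round trip as readable UTF-8 text rather than \uXXXX escapes. A sketch of the kind of serialization that makes this pass; this is an assumption about the underlying change, not the actual basemodel.py code:

    import json
    from pathlib import Path


    def save_json_utf8(path: Path, data: dict) -> None:
        # encoding="utf-8" plus ensure_ascii=False keeps non-ASCII text readable
        # in the saved file instead of escaping it.
        with open(path, "w", encoding="utf-8") as f:
            json.dump(data, f, ensure_ascii=False, indent=2)
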
@@ -369,7 +385,7 @@ def test_task_run_input_source_validation(tmp_path):
      assert task_run.input_source is not None

      # Test 3: Creating without input_source should fail when strict mode is on
-     with patch("kiln_ai.datamodel.strict_mode", return_value=True):
+     with patch("kiln_ai.datamodel.task_run.strict_mode", return_value=True):
          with pytest.raises(ValueError) as exc_info:
              task_run = TaskRun(
                  input="test input 3",
@@ -426,7 +442,7 @@ def test_task_output_source_validation(tmp_path):
      assert task_output.source is not None

      # Test 3: Creating without source should fail when strict mode is on
-     with patch("kiln_ai.datamodel.strict_mode", return_value=True):
+     with patch("kiln_ai.datamodel.task_output.strict_mode", return_value=True):
          with pytest.raises(ValueError) as exc_info:
              task_output = TaskOutput(
                  output="test output 3",
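
The patch targets in the two hunks above change because strict_mode is now looked up inside the split-out task_run and task_output modules (new files in the list above), and unittest.mock.patch must target the name where it is used, not where it is defined. Illustrative pattern, with the module path taken verbatim from the test:

    from unittest.mock import patch

    # Patch the reference inside the module that uses it, not kiln_ai.datamodel itself.
    with patch("kiln_ai.datamodel.task_run.strict_mode", return_value=True):
        ...  # constructing a TaskRun without an input_source should now raise ValueError
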
@@ -488,3 +504,116 @@ def test_task_run_tags_validation():
              tags=["valid_tag", "invalid tag"],
          )
      assert "Tags cannot contain spaces. Try underscores." in str(exc_info.value)
+
+
+ def test_prompt_validation():
+     prompt = Prompt(name="Test Prompt Name", prompt="Test Prompt")
+     assert prompt.name == "Test Prompt Name"
+     assert prompt.prompt == "Test Prompt"
+
+     with pytest.raises(ValidationError):
+         Prompt(name="Test Prompt")
+
+     with pytest.raises(ValidationError):
+         Prompt(name="Test Prompt", prompt=None)
+
+     with pytest.raises(ValidationError):
+         Prompt(name="Test Prompt", prompt="")
+
+     with pytest.raises(ValidationError):
+         Prompt(prompt="Test Prompt")
+
+
+ def test_prompt_parent_task():
+     task = Task(name="Test Task", instruction="Test Instruction")
+     prompt = Prompt(name="Test Prompt", prompt="Test Prompt", parent=task)
+     assert prompt.parent == task
+
+
+ @pytest.mark.parametrize(
+     "thinking_instructions,data_strategy,should_raise,expected_message",
+     [
+         # Test 1: Valid case - no thinking instructions with final_only
+         (
+             None,
+             FinetuneDataStrategy.final_only,
+             False,
+             None,
+         ),
+         # Test 2: Valid case - thinking instructions with final_and_intermediate
+         (
+             "Think step by step",
+             FinetuneDataStrategy.final_and_intermediate,
+             False,
+             None,
+         ),
+         # Test 3: Invalid case - thinking instructions with final_only
+         (
+             "Think step by step",
+             FinetuneDataStrategy.final_only,
+             True,
+             "Thinking instructions can only be used when data_strategy is final_and_intermediate",
+         ),
+         # Test 4: Invalid case - no thinking instructions with final_and_intermediate
+         (
+             None,
+             FinetuneDataStrategy.final_and_intermediate,
+             True,
+             "Thinking instructions are required when data_strategy is final_and_intermediate",
+         ),
+     ],
+ )
+ def test_finetune_thinking_instructions_validation(
+     thinking_instructions, data_strategy, should_raise, expected_message
+ ):
+     base_params = {
+         "name": "test-finetune",
+         "provider": "openai",
+         "base_model_id": "gpt-3.5-turbo",
+         "dataset_split_id": "split1",
+         "system_message": "test message",
+         "data_strategy": data_strategy,
+     }
+
+     if thinking_instructions is not None:
+         base_params["thinking_instructions"] = thinking_instructions
+
+     if should_raise:
+         with pytest.raises(ValueError) as exc_info:
+             Finetune(**base_params)
+         assert expected_message in str(exc_info.value)
+     else:
+         finetune = Finetune(**base_params)
+         assert finetune.thinking_instructions == thinking_instructions
+         assert finetune.data_strategy == data_strategy
+
+
+ @pytest.mark.parametrize(
+     "intermediate_outputs,expected",
+     [
+         # No intermediate outputs
+         (None, False),
+         # Empty intermediate outputs
+         ({}, False),
+         # Only chain_of_thought
+         ({"chain_of_thought": "thinking process"}, True),
+         # Only reasoning
+         ({"reasoning": "reasoning process"}, True),
+         # Both chain_of_thought and reasoning
+         (
+             {"chain_of_thought": "thinking process", "reasoning": "reasoning process"},
+             True,
+         ),
+         # Other intermediate outputs but no thinking data
+         ({"other_output": "some data"}, False),
+         # Mixed other outputs with thinking data
+         ({"chain_of_thought": "thinking process", "other_output": "some data"}, True),
+     ],
+ )
+ def test_task_run_has_thinking_training_data(intermediate_outputs, expected):
+     task_run = TaskRun(
+         input="test input",
+         output=TaskOutput(output="test output"),
+         intermediate_outputs=intermediate_outputs,
+     )
+     assert task_run.has_thinking_training_data() == expected
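
The parametrized tests above fix two new behaviors: Finetune only accepts thinking_instructions when data_strategy is final_and_intermediate (and requires them in that case), and TaskRun.has_thinking_training_data() is true when intermediate_outputs contains a chain_of_thought or reasoning key. Standalone sketches of that logic, reconstructed from the expected messages and cases rather than taken from the actual datamodel code:

    def check_thinking_instructions(thinking_instructions, data_strategy) -> None:
        # Mirrors the Finetune validator exercised above (hypothetical helper).
        if data_strategy == "final_and_intermediate" and not thinking_instructions:
            raise ValueError(
                "Thinking instructions are required when data_strategy is final_and_intermediate"
            )
        if data_strategy != "final_and_intermediate" and thinking_instructions:
            raise ValueError(
                "Thinking instructions can only be used when data_strategy is final_and_intermediate"
            )


    def has_thinking_training_data(intermediate_outputs) -> bool:
        # True when any chain-of-thought/reasoning output was captured for training.
        if not intermediate_outputs:
            return False
        return "chain_of_thought" in intermediate_outputs or "reasoning" in intermediate_outputs
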
@@ -0,0 +1,129 @@
+ import pytest
+ from pydantic import BaseModel, ValidationError
+
+ from kiln_ai.datamodel import (
+     PromptGenerators,
+     PromptId,
+ )
+ from kiln_ai.datamodel.prompt_id import is_frozen_prompt
+
+
+ # Test model to validate the PromptId type
+ class ModelTester(BaseModel):
+     prompt_id: PromptId
+
+
+ def test_valid_prompt_generator_names():
+     """Test that valid prompt generator names are accepted"""
+     for generator in PromptGenerators:
+         model = ModelTester(prompt_id=generator.value)
+         assert model.prompt_id == generator.value
+
+
+ def test_valid_saved_prompt_id():
+     """Test that valid saved prompt IDs are accepted"""
+     valid_id = "id::prompt_789"
+     model = ModelTester(prompt_id=valid_id)
+     assert model.prompt_id == valid_id
+
+
+ def test_valid_fine_tune_prompt_id():
+     """Test that valid fine-tune prompt IDs are accepted"""
+     valid_id = "fine_tune_prompt::ft_123456"
+     model = ModelTester(prompt_id=valid_id)
+     assert model.prompt_id == valid_id
+
+
+ @pytest.mark.parametrize(
+     "invalid_id",
+     [
+         pytest.param("id::project_123::task_456", id="missing_prompt_id"),
+         pytest.param("id::task_456::prompt_789", id="too_many_parts"),
+         pytest.param("id::", id="empty_parts"),
+     ],
+ )
+ def test_invalid_saved_prompt_id_format(invalid_id):
+     """Test that invalid saved prompt ID formats are rejected"""
+     with pytest.raises(ValidationError, match="Invalid saved prompt ID"):
+         ModelTester(prompt_id=invalid_id)
+
+
+ @pytest.mark.parametrize(
+     "invalid_id,expected_error",
+     [
+         ("fine_tune_prompt::", "Invalid fine-tune prompt ID: fine_tune_prompt::"),
+         ("fine_tune_prompt", "Invalid prompt ID: fine_tune_prompt"),
+     ],
+ )
+ def test_invalid_fine_tune_prompt_id_format(invalid_id, expected_error):
+     """Test that invalid fine-tune prompt ID formats are rejected"""
+     with pytest.raises(ValidationError, match=expected_error):
+         ModelTester(prompt_id=invalid_id)
+
+
+ def test_completely_invalid_formats():
+     """Test that completely invalid formats are rejected"""
+     invalid_ids = [
+         "",  # Empty string
+         "invalid_format",  # Random string
+         "id:wrong_format",  # Almost correct but wrong separator
+         "fine_tune:wrong_format",  # Almost correct but wrong prefix
+         ":::",  # Just separators
+     ]
+
+     for invalid_id in invalid_ids:
+         with pytest.raises(ValidationError, match="Invalid prompt ID"):
+             ModelTester(prompt_id=invalid_id)
+
+
+ def test_prompt_generator_case_sensitivity():
+     """Test that prompt generator names are case sensitive"""
+     # Take first generator and modify its case
+     first_generator = next(iter(PromptGenerators)).value
+     wrong_case = first_generator.upper()
+     if wrong_case == first_generator:
+         wrong_case = first_generator.lower()
+
+     with pytest.raises(ValidationError):
+         ModelTester(prompt_id=wrong_case)
+
+
+ @pytest.mark.parametrize(
+     "valid_id",
+     [
+         "task_run_config::project_123::task_456::config_123",  # Valid task run config prompt ID
+     ],
+ )
+ def test_valid_task_run_config_prompt_id(valid_id):
+     """Test that valid eval prompt IDs are accepted"""
+     model = ModelTester(prompt_id=valid_id)
+     assert model.prompt_id == valid_id
+
+
+ @pytest.mark.parametrize(
+     "invalid_id,expected_error",
+     [
+         ("task_run_config::", "Invalid task run config prompt ID"),
+         ("task_run_config::p1", "Invalid task run config prompt ID"),
+         ("task_run_config::p1::t1", "Invalid task run config prompt ID"),
+         ("task_run_config::p1::t1::c1::extra", "Invalid task run config prompt ID"),
+     ],
+ )
+ def test_invalid_eval_prompt_id_format(invalid_id, expected_error):
+     """Test that invalid eval prompt ID formats are rejected"""
+     with pytest.raises(ValidationError, match=expected_error):
+         ModelTester(prompt_id=invalid_id)
+
+
+ @pytest.mark.parametrize(
+     "id,should_be_frozen",
+     [
+         ("simple_prompt_builder", False),
+         ("id::prompt_123", True),
+         ("task_run_config::p1::t1", True),
+         ("fine_tune_prompt::ft_123", True),
+     ],
+ )
+ def test_is_frozen_prompt(id, should_be_frozen):
+     """Test that the is_frozen_prompt function works"""
+     assert is_frozen_prompt(id) == should_be_frozen
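
The new PromptId type accepts either a prompt generator name (e.g. simple_prompt_builder) or a namespaced ID of the form id::<prompt_id>, fine_tune_prompt::<finetune_id>, or task_run_config::<project>::<task>::<config>. Everything except a bare generator name is treated as frozen. A rough sketch of is_frozen_prompt consistent with the cases above; the real validator additionally checks part counts and generator membership:

    def is_frozen_prompt(prompt_id: str) -> bool:
        # Saved, fine-tune, and task-run-config prompts are snapshots that never
        # change; generator names are re-run each time, so they are not frozen.
        return prompt_id.startswith(("id::", "fine_tune_prompt::", "task_run_config::"))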