kiln-ai 0.8.0__py3-none-any.whl → 0.11.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of kiln-ai might be problematic. Click here for more details.
- kiln_ai/adapters/__init__.py +7 -7
- kiln_ai/adapters/adapter_registry.py +77 -5
- kiln_ai/adapters/data_gen/data_gen_task.py +3 -3
- kiln_ai/adapters/data_gen/test_data_gen_task.py +23 -3
- kiln_ai/adapters/fine_tune/base_finetune.py +5 -1
- kiln_ai/adapters/fine_tune/dataset_formatter.py +310 -65
- kiln_ai/adapters/fine_tune/fireworks_finetune.py +47 -32
- kiln_ai/adapters/fine_tune/openai_finetune.py +12 -11
- kiln_ai/adapters/fine_tune/test_base_finetune.py +19 -0
- kiln_ai/adapters/fine_tune/test_dataset_formatter.py +469 -129
- kiln_ai/adapters/fine_tune/test_fireworks_tinetune.py +113 -21
- kiln_ai/adapters/fine_tune/test_openai_finetune.py +125 -14
- kiln_ai/adapters/ml_model_list.py +323 -94
- kiln_ai/adapters/model_adapters/__init__.py +18 -0
- kiln_ai/adapters/{base_adapter.py → model_adapters/base_adapter.py} +81 -37
- kiln_ai/adapters/{langchain_adapters.py → model_adapters/langchain_adapters.py} +130 -84
- kiln_ai/adapters/model_adapters/openai_compatible_config.py +11 -0
- kiln_ai/adapters/model_adapters/openai_model_adapter.py +246 -0
- kiln_ai/adapters/model_adapters/test_base_adapter.py +190 -0
- kiln_ai/adapters/{test_langchain_adapter.py → model_adapters/test_langchain_adapter.py} +103 -88
- kiln_ai/adapters/model_adapters/test_openai_model_adapter.py +225 -0
- kiln_ai/adapters/{test_saving_adapter_results.py → model_adapters/test_saving_adapter_results.py} +43 -15
- kiln_ai/adapters/{test_structured_output.py → model_adapters/test_structured_output.py} +93 -20
- kiln_ai/adapters/parsers/__init__.py +10 -0
- kiln_ai/adapters/parsers/base_parser.py +12 -0
- kiln_ai/adapters/parsers/json_parser.py +37 -0
- kiln_ai/adapters/parsers/parser_registry.py +19 -0
- kiln_ai/adapters/parsers/r1_parser.py +69 -0
- kiln_ai/adapters/parsers/test_json_parser.py +81 -0
- kiln_ai/adapters/parsers/test_parser_registry.py +32 -0
- kiln_ai/adapters/parsers/test_r1_parser.py +144 -0
- kiln_ai/adapters/prompt_builders.py +126 -20
- kiln_ai/adapters/provider_tools.py +91 -36
- kiln_ai/adapters/repair/repair_task.py +17 -6
- kiln_ai/adapters/repair/test_repair_task.py +4 -4
- kiln_ai/adapters/run_output.py +8 -0
- kiln_ai/adapters/test_adapter_registry.py +177 -0
- kiln_ai/adapters/test_generate_docs.py +69 -0
- kiln_ai/adapters/test_prompt_adaptors.py +8 -4
- kiln_ai/adapters/test_prompt_builders.py +190 -29
- kiln_ai/adapters/test_provider_tools.py +268 -46
- kiln_ai/datamodel/__init__.py +199 -12
- kiln_ai/datamodel/basemodel.py +31 -11
- kiln_ai/datamodel/json_schema.py +8 -3
- kiln_ai/datamodel/model_cache.py +8 -3
- kiln_ai/datamodel/test_basemodel.py +81 -2
- kiln_ai/datamodel/test_dataset_split.py +100 -3
- kiln_ai/datamodel/test_example_models.py +25 -4
- kiln_ai/datamodel/test_model_cache.py +24 -0
- kiln_ai/datamodel/test_model_perf.py +125 -0
- kiln_ai/datamodel/test_models.py +129 -0
- kiln_ai/utils/exhaustive_error.py +6 -0
- {kiln_ai-0.8.0.dist-info → kiln_ai-0.11.1.dist-info}/METADATA +9 -7
- kiln_ai-0.11.1.dist-info/RECORD +76 -0
- kiln_ai-0.8.0.dist-info/RECORD +0 -58
- {kiln_ai-0.8.0.dist-info → kiln_ai-0.11.1.dist-info}/WHEEL +0 -0
- {kiln_ai-0.8.0.dist-info → kiln_ai-0.11.1.dist-info}/licenses/LICENSE.txt +0 -0
kiln_ai/datamodel/__init__.py
CHANGED
|
@@ -49,11 +49,18 @@ __all__ = [
|
|
|
49
49
|
"DataSource",
|
|
50
50
|
"DataSourceType",
|
|
51
51
|
"DataSourceProperty",
|
|
52
|
+
"Finetune",
|
|
53
|
+
"FineTuneStatusType",
|
|
52
54
|
"TaskOutputRatingType",
|
|
53
55
|
"TaskRequirement",
|
|
54
56
|
"TaskDeterminism",
|
|
57
|
+
"DatasetSplitDefinition",
|
|
58
|
+
"DatasetSplit",
|
|
59
|
+
"RequirementRating",
|
|
60
|
+
"TaskRequirement",
|
|
55
61
|
"strict_mode",
|
|
56
62
|
"set_strict_mode",
|
|
63
|
+
"Prompt",
|
|
57
64
|
]
|
|
58
65
|
|
|
59
66
|
|
|
@@ -268,12 +275,47 @@ class FineTuneStatusType(str, Enum):
|
|
|
268
275
|
failed = "failed"
|
|
269
276
|
|
|
270
277
|
|
|
278
|
+
class StructuredOutputMode(str, Enum):
|
|
279
|
+
"""
|
|
280
|
+
Enumeration of supported structured output modes.
|
|
281
|
+
|
|
282
|
+
- default: let the adapter decide
|
|
283
|
+
- json_schema: request json using API capabilities for json_schema
|
|
284
|
+
- function_calling: request json using API capabilities for function calling
|
|
285
|
+
- json_mode: request json using API's JSON mode, which should return valid JSON, but isn't checking/passing the schema
|
|
286
|
+
- json_instructions: append instructions to the prompt to request json matching the schema. No API capabilities are used. You should have a custom parser on these models as they will be returning strings.
|
|
287
|
+
- json_instruction_and_object: append instructions to the prompt to request json matching the schema. Also request the response as json_mode via API capabilities (returning dictionaries).
|
|
288
|
+
"""
|
|
289
|
+
|
|
290
|
+
default = "default"
|
|
291
|
+
json_schema = "json_schema"
|
|
292
|
+
function_calling = "function_calling"
|
|
293
|
+
json_mode = "json_mode"
|
|
294
|
+
json_instructions = "json_instructions"
|
|
295
|
+
json_instruction_and_object = "json_instruction_and_object"
|
|
296
|
+
|
|
297
|
+
|
|
298
|
+
class FinetuneDataStrategy(str, Enum):
|
|
299
|
+
final_only = "final_only"
|
|
300
|
+
final_and_intermediate = "final_and_intermediate"
|
|
301
|
+
|
|
302
|
+
|
|
271
303
|
class Finetune(KilnParentedModel):
|
|
304
|
+
"""
|
|
305
|
+
The Kiln fine-tune datamodel.
|
|
306
|
+
|
|
307
|
+
Initially holds a reference to a training job, with needed identifiers to update the status. When complete, contains the new model ID.
|
|
308
|
+
"""
|
|
309
|
+
|
|
272
310
|
name: str = NAME_FIELD
|
|
273
311
|
description: str | None = Field(
|
|
274
312
|
default=None,
|
|
275
313
|
description="A description of the fine-tune for you and your team. Not used in training.",
|
|
276
314
|
)
|
|
315
|
+
structured_output_mode: StructuredOutputMode | None = Field(
|
|
316
|
+
default=None,
|
|
317
|
+
description="The mode to use to train the model for structured output, if it was trained with structured output. Will determine how we call the tuned model, so we call with the matching mode.",
|
|
318
|
+
)
|
|
277
319
|
provider: str = Field(
|
|
278
320
|
description="The provider to use for the fine-tune (e.g. 'openai')."
|
|
279
321
|
)
|
|
@@ -303,9 +345,14 @@ class Finetune(KilnParentedModel):
|
|
|
303
345
|
default={},
|
|
304
346
|
description="The parameters to use for this fine-tune. These are provider-specific.",
|
|
305
347
|
)
|
|
348
|
+
# These two fields are saved exactly as used for training. Even if they map exactly to a custom prompt or generator, those can change, so we want to keep a record of the training prompt.
|
|
306
349
|
system_message: str = Field(
|
|
307
350
|
description="The system message to use for this fine-tune.",
|
|
308
351
|
)
|
|
352
|
+
thinking_instructions: str | None = Field(
|
|
353
|
+
default=None,
|
|
354
|
+
description="The thinking instructions to use for this fine-tune. Only used when data_strategy is final_and_intermediate.",
|
|
355
|
+
)
|
|
309
356
|
latest_status: FineTuneStatusType = Field(
|
|
310
357
|
default=FineTuneStatusType.unknown,
|
|
311
358
|
description="The latest known status of this fine-tune. Not updated in real time.",
|
|
@@ -314,12 +361,34 @@ class Finetune(KilnParentedModel):
|
|
|
314
361
|
default={},
|
|
315
362
|
description="Properties of the fine-tune. Different providers may use different properties.",
|
|
316
363
|
)
|
|
364
|
+
data_strategy: FinetuneDataStrategy = Field(
|
|
365
|
+
default=FinetuneDataStrategy.final_only,
|
|
366
|
+
description="The strategy to use for training the model. 'final_only' will only train on the final response. 'final_and_intermediate' will train on the final response and intermediate outputs (chain of thought or reasoning).",
|
|
367
|
+
)
|
|
317
368
|
|
|
318
369
|
def parent_task(self) -> Task | None:
|
|
319
370
|
if not isinstance(self.parent, Task):
|
|
320
371
|
return None
|
|
321
372
|
return self.parent
|
|
322
373
|
|
|
374
|
+
@model_validator(mode="after")
|
|
375
|
+
def validate_thinking_instructions(self) -> Self:
|
|
376
|
+
if (
|
|
377
|
+
self.thinking_instructions is not None
|
|
378
|
+
and self.data_strategy != FinetuneDataStrategy.final_and_intermediate
|
|
379
|
+
):
|
|
380
|
+
raise ValueError(
|
|
381
|
+
"Thinking instructions can only be used when data_strategy is final_and_intermediate"
|
|
382
|
+
)
|
|
383
|
+
if (
|
|
384
|
+
self.thinking_instructions is None
|
|
385
|
+
and self.data_strategy == FinetuneDataStrategy.final_and_intermediate
|
|
386
|
+
):
|
|
387
|
+
raise ValueError(
|
|
388
|
+
"Thinking instructions are required when data_strategy is final_and_intermediate"
|
|
389
|
+
)
|
|
390
|
+
return self
|
|
391
|
+
|
|
323
392
|
|
|
324
393
|
class DataSourceType(str, Enum):
|
|
325
394
|
"""
|
|
@@ -391,6 +460,13 @@ class DataSource(BaseModel):
|
|
|
391
460
|
type=str,
|
|
392
461
|
not_allowed_for=[DataSourceType.human],
|
|
393
462
|
),
|
|
463
|
+
DataSourceProperty(
|
|
464
|
+
# Optional: an ID within the scope of the prompt_builder_name.
|
|
465
|
+
# Used for prompt builders with IDs (like saved prompts, fine-tune prompts)
|
|
466
|
+
name="prompt_id",
|
|
467
|
+
type=str,
|
|
468
|
+
not_allowed_for=[DataSourceType.human],
|
|
469
|
+
),
|
|
394
470
|
]
|
|
395
471
|
|
|
396
472
|
@model_validator(mode="after")
|
|
@@ -464,13 +540,39 @@ class TaskRun(KilnParentedModel):
|
|
|
464
540
|
description="Tags for the task run. Tags are used to categorize task runs for filtering and reporting.",
|
|
465
541
|
)
|
|
466
542
|
|
|
543
|
+
def has_thinking_training_data(self) -> bool:
|
|
544
|
+
"""
|
|
545
|
+
Does this run have thinking data that we can use to train a thinking model?
|
|
546
|
+
"""
|
|
547
|
+
if self.intermediate_outputs is None:
|
|
548
|
+
return False
|
|
549
|
+
return (
|
|
550
|
+
"chain_of_thought" in self.intermediate_outputs
|
|
551
|
+
or "reasoning" in self.intermediate_outputs
|
|
552
|
+
)
|
|
553
|
+
|
|
467
554
|
def parent_task(self) -> Task | None:
|
|
468
555
|
if not isinstance(self.parent, Task):
|
|
469
556
|
return None
|
|
470
557
|
return self.parent
|
|
471
558
|
|
|
472
559
|
@model_validator(mode="after")
|
|
473
|
-
def validate_input_format(self) -> Self:
|
|
560
|
+
def validate_input_format(self, info: ValidationInfo) -> Self:
|
|
561
|
+
# Don't validate if loading from file (not new). Too slow.
|
|
562
|
+
# We don't allow changing task schema, so this is redundant validation.
|
|
563
|
+
# Note: we still validate if editing a loaded model
|
|
564
|
+
if self.loading_from_file(info):
|
|
565
|
+
# Consider loading an existing model as validated.
|
|
566
|
+
self._last_validated_input = self.input
|
|
567
|
+
return self
|
|
568
|
+
|
|
569
|
+
# Don't validate if input has not changed. Too slow to run this every time.
|
|
570
|
+
if (
|
|
571
|
+
hasattr(self, "_last_validated_input")
|
|
572
|
+
and self.input == self._last_validated_input
|
|
573
|
+
):
|
|
574
|
+
return self
|
|
575
|
+
|
|
474
576
|
task = self.parent_task()
|
|
475
577
|
if task is None:
|
|
476
578
|
# don't validate this relationship until we have a path or parent. Give them time to build it (but will catch it before saving)
|
|
@@ -484,15 +586,33 @@ class TaskRun(KilnParentedModel):
|
|
|
484
586
|
raise ValueError("Input is not a valid JSON object")
|
|
485
587
|
except jsonschema.exceptions.ValidationError as e:
|
|
486
588
|
raise ValueError(f"Input does not match task input schema: {e}")
|
|
589
|
+
self._last_validated_input = self.input
|
|
487
590
|
return self
|
|
488
591
|
|
|
489
592
|
@model_validator(mode="after")
|
|
490
|
-
def validate_output_format(self) -> Self:
|
|
593
|
+
def validate_output_format(self, info: ValidationInfo) -> Self:
|
|
594
|
+
# Don't validate if loading from file (not new). Too slow.
|
|
595
|
+
# Note: we still validate if editing a loaded model's output.
|
|
596
|
+
if self.loading_from_file(info):
|
|
597
|
+
# Consider loading an existing model as validated.
|
|
598
|
+
self._last_validated_output = self.output.output if self.output else None
|
|
599
|
+
return self
|
|
600
|
+
|
|
601
|
+
# Don't validate unless output has changed since last validation.
|
|
602
|
+
# The validator is slow and costly, don't want it running when setting other fields.
|
|
603
|
+
if (
|
|
604
|
+
hasattr(self, "_last_validated_output")
|
|
605
|
+
and self.output is not None
|
|
606
|
+
and self.output.output == self._last_validated_output
|
|
607
|
+
):
|
|
608
|
+
return self
|
|
609
|
+
|
|
491
610
|
task = self.parent_task()
|
|
492
611
|
if task is None:
|
|
493
612
|
return self
|
|
494
613
|
|
|
495
614
|
self.output.validate_output_format(task)
|
|
615
|
+
self._last_validated_output = self.output.output if self.output else None
|
|
496
616
|
return self
|
|
497
617
|
|
|
498
618
|
@model_validator(mode="after")
|
|
@@ -544,11 +664,47 @@ def AllDatasetFilter(_: TaskRun) -> bool:
|
|
|
544
664
|
|
|
545
665
|
|
|
546
666
|
def HighRatingDatasetFilter(task_run: TaskRun) -> bool:
|
|
547
|
-
if task_run.output is None
|
|
667
|
+
if task_run.output is None:
|
|
668
|
+
return False
|
|
669
|
+
if task_run.repaired_output is not None:
|
|
670
|
+
# Repairs always considered high quality
|
|
671
|
+
return True
|
|
672
|
+
if task_run.output.rating is None:
|
|
548
673
|
return False
|
|
549
674
|
return task_run.output.rating.is_high_quality()
|
|
550
675
|
|
|
551
676
|
|
|
677
|
+
def ThinkingModelDatasetFilter(task_run: TaskRun) -> bool:
|
|
678
|
+
"""
|
|
679
|
+
A filter that returns True if the task has intermediate outputs we can train a 'thinking' model on (reasoning or chain of thought)
|
|
680
|
+
"""
|
|
681
|
+
return task_run.has_thinking_training_data()
|
|
682
|
+
|
|
683
|
+
|
|
684
|
+
def ThinkingModelHighRatedFilter(task_run: TaskRun) -> bool:
|
|
685
|
+
"""
|
|
686
|
+
A filter that returns True if the task has thinking data and the output is high quality
|
|
687
|
+
"""
|
|
688
|
+
return ThinkingModelDatasetFilter(task_run) and HighRatingDatasetFilter(task_run)
|
|
689
|
+
|
|
690
|
+
|
|
691
|
+
class DatasetFilterType(str, Enum):
|
|
692
|
+
"""Dataset filter names."""
|
|
693
|
+
|
|
694
|
+
ALL = "all"
|
|
695
|
+
HIGH_RATING = "high_rating"
|
|
696
|
+
THINKING_MODEL = "thinking_model"
|
|
697
|
+
THINKING_MODEL_HIGH_RATED = "thinking_model_high_rated"
|
|
698
|
+
|
|
699
|
+
|
|
700
|
+
dataset_filters = {
|
|
701
|
+
DatasetFilterType.ALL: AllDatasetFilter,
|
|
702
|
+
DatasetFilterType.HIGH_RATING: HighRatingDatasetFilter,
|
|
703
|
+
DatasetFilterType.THINKING_MODEL: ThinkingModelDatasetFilter,
|
|
704
|
+
DatasetFilterType.THINKING_MODEL_HIGH_RATED: ThinkingModelHighRatedFilter,
|
|
705
|
+
}
|
|
706
|
+
|
|
707
|
+
|
|
552
708
|
class DatasetSplitDefinition(BaseModel):
|
|
553
709
|
"""
|
|
554
710
|
A definition of a split in a dataset.
|
|
@@ -580,6 +736,11 @@ Train60Test20Val20SplitDefinition: list[DatasetSplitDefinition] = [
|
|
|
580
736
|
DatasetSplitDefinition(name="test", percentage=0.2),
|
|
581
737
|
DatasetSplitDefinition(name="val", percentage=0.2),
|
|
582
738
|
]
|
|
739
|
+
Train80Test10Val10SplitDefinition: list[DatasetSplitDefinition] = [
|
|
740
|
+
DatasetSplitDefinition(name="train", percentage=0.8),
|
|
741
|
+
DatasetSplitDefinition(name="test", percentage=0.1),
|
|
742
|
+
DatasetSplitDefinition(name="val", percentage=0.1),
|
|
743
|
+
]
|
|
583
744
|
|
|
584
745
|
|
|
585
746
|
class DatasetSplit(KilnParentedModel):
|
|
@@ -603,6 +764,10 @@ class DatasetSplit(KilnParentedModel):
|
|
|
603
764
|
split_contents: dict[str, list[str]] = Field(
|
|
604
765
|
description="The contents of each split in the dataset. The key is the split name, and the value is a list of task run IDs.",
|
|
605
766
|
)
|
|
767
|
+
filter: DatasetFilterType | None = Field(
|
|
768
|
+
default=None,
|
|
769
|
+
description="The filter used to build the dataset.",
|
|
770
|
+
)
|
|
606
771
|
|
|
607
772
|
@model_validator(mode="after")
|
|
608
773
|
def validate_split_percentages(self) -> "DatasetSplit":
|
|
@@ -617,12 +782,13 @@ class DatasetSplit(KilnParentedModel):
|
|
|
617
782
|
name: str,
|
|
618
783
|
task: "Task",
|
|
619
784
|
splits: list[DatasetSplitDefinition],
|
|
620
|
-
|
|
785
|
+
filter_type: DatasetFilterType = DatasetFilterType.ALL,
|
|
621
786
|
description: str | None = None,
|
|
622
787
|
):
|
|
623
788
|
"""
|
|
624
789
|
Build a dataset split from a task.
|
|
625
790
|
"""
|
|
791
|
+
filter = dataset_filters[filter_type]
|
|
626
792
|
split_contents = cls.build_split_contents(task, splits, filter)
|
|
627
793
|
return cls(
|
|
628
794
|
parent=task,
|
|
@@ -630,6 +796,7 @@ class DatasetSplit(KilnParentedModel):
|
|
|
630
796
|
description=description,
|
|
631
797
|
splits=splits,
|
|
632
798
|
split_contents=split_contents,
|
|
799
|
+
filter=filter_type,
|
|
633
800
|
)
|
|
634
801
|
|
|
635
802
|
@classmethod
|
|
@@ -680,7 +847,7 @@ class DatasetSplit(KilnParentedModel):
|
|
|
680
847
|
if parent is None:
|
|
681
848
|
raise ValueError("DatasetSplit has no parent task")
|
|
682
849
|
|
|
683
|
-
runs = parent.runs()
|
|
850
|
+
runs = parent.runs(readonly=True)
|
|
684
851
|
all_ids = set(run.id for run in runs)
|
|
685
852
|
all_ids_in_splits = set()
|
|
686
853
|
for ids in self.split_contents.values():
|
|
@@ -689,6 +856,22 @@ class DatasetSplit(KilnParentedModel):
|
|
|
689
856
|
return len(missing)
|
|
690
857
|
|
|
691
858
|
|
|
859
|
+
class Prompt(KilnParentedModel):
|
|
860
|
+
"""
|
|
861
|
+
A prompt for a task.
|
|
862
|
+
"""
|
|
863
|
+
|
|
864
|
+
name: str = NAME_FIELD
|
|
865
|
+
prompt: str = Field(
|
|
866
|
+
description="The prompt for the task.",
|
|
867
|
+
min_length=1,
|
|
868
|
+
)
|
|
869
|
+
chain_of_thought_instructions: str | None = Field(
|
|
870
|
+
default=None,
|
|
871
|
+
description="Instructions for the model 'thinking' about the requirement prior to answering. Used for chain of thought style prompting. COT will not be used unless this is provided.",
|
|
872
|
+
)
|
|
873
|
+
|
|
874
|
+
|
|
692
875
|
class TaskRequirement(BaseModel):
|
|
693
876
|
"""
|
|
694
877
|
Defines a specific requirement that should be met by task outputs.
|
|
@@ -726,6 +909,7 @@ class Task(
|
|
|
726
909
|
"runs": TaskRun,
|
|
727
910
|
"dataset_splits": DatasetSplit,
|
|
728
911
|
"finetunes": Finetune,
|
|
912
|
+
"prompts": Prompt,
|
|
729
913
|
},
|
|
730
914
|
):
|
|
731
915
|
"""
|
|
@@ -762,15 +946,18 @@ class Task(
|
|
|
762
946
|
return None
|
|
763
947
|
return schema_from_json_str(self.input_json_schema)
|
|
764
948
|
|
|
765
|
-
#
|
|
766
|
-
def runs(self) -> list[TaskRun]:
|
|
767
|
-
return super().runs() # type: ignore
|
|
949
|
+
# These wrappers help for typechecking. TODO P2: fix this in KilnParentModel
|
|
950
|
+
def runs(self, readonly: bool = False) -> list[TaskRun]:
|
|
951
|
+
return super().runs(readonly=readonly) # type: ignore
|
|
952
|
+
|
|
953
|
+
def dataset_splits(self, readonly: bool = False) -> list[DatasetSplit]:
|
|
954
|
+
return super().dataset_splits(readonly=readonly) # type: ignore
|
|
768
955
|
|
|
769
|
-
def
|
|
770
|
-
return super().
|
|
956
|
+
def finetunes(self, readonly: bool = False) -> list[Finetune]:
|
|
957
|
+
return super().finetunes(readonly=readonly) # type: ignore
|
|
771
958
|
|
|
772
|
-
def
|
|
773
|
-
return super().
|
|
959
|
+
def prompts(self, readonly: bool = False) -> list[Prompt]:
|
|
960
|
+
return super().prompts(readonly=readonly) # type: ignore
|
|
774
961
|
|
|
775
962
|
|
|
776
963
|
class Project(KilnParentModel, parent_of={"tasks": Task}):
|
kiln_ai/datamodel/basemodel.py
CHANGED
|
@@ -120,11 +120,12 @@ class KilnBaseModel(BaseModel):
|
|
|
120
120
|
return cls.load_from_file(path)
|
|
121
121
|
|
|
122
122
|
@classmethod
|
|
123
|
-
def load_from_file(cls: Type[T], path: Path | str) -> T:
|
|
123
|
+
def load_from_file(cls: Type[T], path: Path | str, readonly: bool = False) -> T:
|
|
124
124
|
"""Load a model instance from a specific file path.
|
|
125
125
|
|
|
126
126
|
Args:
|
|
127
127
|
path (Path): Path to the model file
|
|
128
|
+
readonly (bool): If True, the model will be returned in readonly mode (cached instance, not a copy, not safe to mutate)
|
|
128
129
|
|
|
129
130
|
Returns:
|
|
130
131
|
T: Instance of the model
|
|
@@ -135,10 +136,10 @@ class KilnBaseModel(BaseModel):
|
|
|
135
136
|
"""
|
|
136
137
|
if isinstance(path, str):
|
|
137
138
|
path = Path(path)
|
|
138
|
-
cached_model = ModelCache.shared().get_model(path, cls)
|
|
139
|
+
cached_model = ModelCache.shared().get_model(path, cls, readonly=readonly)
|
|
139
140
|
if cached_model is not None:
|
|
140
141
|
return cached_model
|
|
141
|
-
with open(path, "r") as file:
|
|
142
|
+
with open(path, "r", encoding="utf-8") as file:
|
|
142
143
|
# modified time of file for cache invalidation. From file descriptor so it's atomic w read.
|
|
143
144
|
mtime_ns = os.fstat(file.fileno()).st_mtime_ns
|
|
144
145
|
file_data = file.read()
|
|
@@ -168,13 +169,20 @@ class KilnBaseModel(BaseModel):
|
|
|
168
169
|
# Two ways of indicating it's loaded from file:
|
|
169
170
|
# 1) info.context.get("loading_from_file") -> During actual loading, before we can set _loaded_from_file
|
|
170
171
|
# 2) self._loaded_from_file -> After loading, set by the loader
|
|
172
|
+
if self.loading_from_file(info):
|
|
173
|
+
return True
|
|
174
|
+
return self._loaded_from_file
|
|
175
|
+
|
|
176
|
+
# indicates the model is currently being loaded from file (not mutating it after)
|
|
177
|
+
def loading_from_file(self, info: ValidationInfo | None = None) -> bool:
|
|
178
|
+
# info.context.get("loading_from_file") -> During actual loading, before we can set _loaded_from_file
|
|
171
179
|
if (
|
|
172
180
|
info is not None
|
|
173
181
|
and info.context is not None
|
|
174
182
|
and info.context.get("loading_from_file", False)
|
|
175
183
|
):
|
|
176
184
|
return True
|
|
177
|
-
return
|
|
185
|
+
return False
|
|
178
186
|
|
|
179
187
|
def save_to_file(self) -> None:
|
|
180
188
|
"""Save the model instance to a file.
|
|
@@ -190,7 +198,7 @@ class KilnBaseModel(BaseModel):
|
|
|
190
198
|
)
|
|
191
199
|
path.parent.mkdir(parents=True, exist_ok=True)
|
|
192
200
|
json_data = self.model_dump_json(indent=2, exclude={"path"})
|
|
193
|
-
with open(path, "w") as file:
|
|
201
|
+
with open(path, "w", encoding="utf-8") as file:
|
|
194
202
|
file.write(json_data)
|
|
195
203
|
# save the path so even if something like name changes, the file doesn't move
|
|
196
204
|
self.path = path
|
|
@@ -342,16 +350,28 @@ class KilnParentedModel(KilnBaseModel, metaclass=ABCMeta):
|
|
|
342
350
|
return []
|
|
343
351
|
|
|
344
352
|
# Collect all /relationship/{id}/{base_filename.kiln} files in the relationship folder
|
|
345
|
-
for
|
|
346
|
-
|
|
353
|
+
# manual code instead of glob for performance (5x speedup over glob)
|
|
354
|
+
|
|
355
|
+
base_filename = cls.base_filename()
|
|
356
|
+
# Iterate through immediate subdirectories using scandir for better performance
|
|
357
|
+
# Benchmark: scandir is 10x faster than glob, so worth the extra code
|
|
358
|
+
with os.scandir(relationship_folder) as entries:
|
|
359
|
+
for entry in entries:
|
|
360
|
+
if not entry.is_dir():
|
|
361
|
+
continue
|
|
362
|
+
|
|
363
|
+
child_file = Path(entry.path) / base_filename
|
|
364
|
+
if child_file.is_file():
|
|
365
|
+
yield child_file
|
|
347
366
|
|
|
348
367
|
@classmethod
|
|
349
368
|
def all_children_of_parent_path(
|
|
350
|
-
cls: Type[PT], parent_path: Path | None
|
|
369
|
+
cls: Type[PT], parent_path: Path | None, readonly: bool = False
|
|
351
370
|
) -> list[PT]:
|
|
352
371
|
children = []
|
|
353
372
|
for child_path in cls.iterate_children_paths_of_parent_path(parent_path):
|
|
354
|
-
|
|
373
|
+
item = cls.load_from_file(child_path, readonly=readonly)
|
|
374
|
+
children.append(item)
|
|
355
375
|
return children
|
|
356
376
|
|
|
357
377
|
@classmethod
|
|
@@ -394,8 +414,8 @@ class KilnParentModel(KilnBaseModel, metaclass=ABCMeta):
|
|
|
394
414
|
def _create_child_method(
|
|
395
415
|
cls, relationship_name: str, child_class: Type[KilnParentedModel]
|
|
396
416
|
):
|
|
397
|
-
def child_method(self) -> list[child_class]:
|
|
398
|
-
return child_class.all_children_of_parent_path(self.path)
|
|
417
|
+
def child_method(self, readonly: bool = False) -> list[child_class]:
|
|
418
|
+
return child_class.all_children_of_parent_path(self.path, readonly=readonly)
|
|
399
419
|
|
|
400
420
|
child_method.__name__ = relationship_name
|
|
401
421
|
child_method.__annotations__ = {"return": List[child_class]}
|
kiln_ai/datamodel/json_schema.py
CHANGED
|
@@ -42,9 +42,14 @@ def validate_schema(instance: Dict, schema_str: str) -> None:
|
|
|
42
42
|
jsonschema.exceptions.ValidationError: If validation fails
|
|
43
43
|
ValueError: If the schema is invalid
|
|
44
44
|
"""
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
45
|
+
try:
|
|
46
|
+
schema = schema_from_json_str(schema_str)
|
|
47
|
+
v = jsonschema.Draft202012Validator(schema)
|
|
48
|
+
v.validate(instance)
|
|
49
|
+
except jsonschema.exceptions.ValidationError as e:
|
|
50
|
+
raise ValueError(
|
|
51
|
+
f"This task requires a specific output schema. While the model produced JSON, that JSON didn't meet the schema. Search 'Troubleshooting Structured Data Issues' in our docs for more information. The error from the schema check was: {e.message}"
|
|
52
|
+
) from e
|
|
48
53
|
|
|
49
54
|
|
|
50
55
|
def schema_from_json_str(v: str) -> Dict:
|
kiln_ai/datamodel/model_cache.py
CHANGED
|
@@ -62,12 +62,17 @@ class ModelCache:
|
|
|
62
62
|
raise ValueError(f"Model at {path} is not of type {model_type.__name__}")
|
|
63
63
|
return model
|
|
64
64
|
|
|
65
|
-
def get_model(
|
|
66
|
-
|
|
65
|
+
def get_model(
|
|
66
|
+
self, path: Path, model_type: Type[T], readonly: bool = False
|
|
67
|
+
) -> Optional[T]:
|
|
68
|
+
# We return a copy by default, so in-memory edits don't impact the cache until they are saved
|
|
67
69
|
# Benchmark shows about 2x slower, but much more foolproof
|
|
68
70
|
model = self._get_model(path, model_type)
|
|
69
71
|
if model:
|
|
70
|
-
|
|
72
|
+
if readonly:
|
|
73
|
+
return model
|
|
74
|
+
else:
|
|
75
|
+
return model.model_copy(deep=True)
|
|
71
76
|
return None
|
|
72
77
|
|
|
73
78
|
def get_model_id(self, path: Path, model_type: Type[T]) -> Optional[str]:
|
|
@@ -6,6 +6,9 @@ from unittest.mock import MagicMock, patch
|
|
|
6
6
|
|
|
7
7
|
import pytest
|
|
8
8
|
|
|
9
|
+
from kiln_ai.adapters.model_adapters.base_adapter import AdapterInfo, BaseAdapter
|
|
10
|
+
from kiln_ai.adapters.run_output import RunOutput
|
|
11
|
+
from kiln_ai.datamodel import Task, TaskRun
|
|
9
12
|
from kiln_ai.datamodel.basemodel import (
|
|
10
13
|
KilnBaseModel,
|
|
11
14
|
KilnParentedModel,
|
|
@@ -356,7 +359,9 @@ def test_load_from_file_with_cache(test_base_file, tmp_model_cache):
|
|
|
356
359
|
model = KilnBaseModel.load_from_file(test_base_file)
|
|
357
360
|
|
|
358
361
|
# Check that the cache was checked and set
|
|
359
|
-
tmp_model_cache.get_model.assert_called_once_with(
|
|
362
|
+
tmp_model_cache.get_model.assert_called_once_with(
|
|
363
|
+
test_base_file, KilnBaseModel, readonly=False
|
|
364
|
+
)
|
|
360
365
|
tmp_model_cache.set_model.assert_called_once()
|
|
361
366
|
|
|
362
367
|
# Ensure the model is correctly loaded
|
|
@@ -407,7 +412,9 @@ def test_load_from_file_with_cached_model(test_base_file, tmp_model_cache):
|
|
|
407
412
|
model = KilnBaseModel.load_from_file(test_base_file)
|
|
408
413
|
|
|
409
414
|
# Check that the cache was checked and the cached model was returned
|
|
410
|
-
tmp_model_cache.get_model.assert_called_once_with(
|
|
415
|
+
tmp_model_cache.get_model.assert_called_once_with(
|
|
416
|
+
test_base_file, KilnBaseModel, readonly=False
|
|
417
|
+
)
|
|
411
418
|
assert model is cached_model
|
|
412
419
|
|
|
413
420
|
# Assert that open was not called (we used the cached model, not file)
|
|
@@ -469,3 +476,75 @@ def test_from_id_and_parent_path_without_parent():
|
|
|
469
476
|
# Test with None parent_path
|
|
470
477
|
not_found = DefaultParentedModel.from_id_and_parent_path("any-id", None)
|
|
471
478
|
assert not_found is None
|
|
479
|
+
|
|
480
|
+
|
|
481
|
+
class MockAdapter(BaseAdapter):
|
|
482
|
+
"""Implementation of BaseAdapter for testing"""
|
|
483
|
+
|
|
484
|
+
async def _run(self, input):
|
|
485
|
+
return RunOutput(output="test output", intermediate_outputs=None)
|
|
486
|
+
|
|
487
|
+
def adapter_info(self) -> AdapterInfo:
|
|
488
|
+
return AdapterInfo(
|
|
489
|
+
adapter_name="test",
|
|
490
|
+
model_name=self.model_name,
|
|
491
|
+
model_provider=self.model_provider_name,
|
|
492
|
+
prompt_builder_name="test",
|
|
493
|
+
)
|
|
494
|
+
|
|
495
|
+
|
|
496
|
+
@pytest.fixture
|
|
497
|
+
def base_task():
|
|
498
|
+
return Task(name="test_task", instruction="test_instruction")
|
|
499
|
+
|
|
500
|
+
|
|
501
|
+
@pytest.fixture
|
|
502
|
+
def adapter(base_task):
|
|
503
|
+
return MockAdapter(
|
|
504
|
+
kiln_task=base_task,
|
|
505
|
+
model_name="test_model",
|
|
506
|
+
model_provider_name="test_provider",
|
|
507
|
+
)
|
|
508
|
+
|
|
509
|
+
|
|
510
|
+
async def test_invoke_parsing_flow(adapter):
|
|
511
|
+
# Mock dependencies
|
|
512
|
+
mock_provider = MagicMock()
|
|
513
|
+
mock_provider.parser = "test_parser"
|
|
514
|
+
|
|
515
|
+
mock_parser = MagicMock()
|
|
516
|
+
mock_parser.parse_output.return_value = RunOutput(
|
|
517
|
+
output="parsed test output", intermediate_outputs={"key": "value"}
|
|
518
|
+
)
|
|
519
|
+
|
|
520
|
+
mock_parser_class = MagicMock(return_value=mock_parser)
|
|
521
|
+
|
|
522
|
+
with (
|
|
523
|
+
patch.object(adapter, "model_provider", return_value=mock_provider),
|
|
524
|
+
patch(
|
|
525
|
+
"kiln_ai.adapters.model_adapters.base_adapter.model_parser_from_id",
|
|
526
|
+
return_value=mock_parser_class,
|
|
527
|
+
),
|
|
528
|
+
patch("kiln_ai.adapters.model_adapters.base_adapter.Config") as mock_config,
|
|
529
|
+
):
|
|
530
|
+
# Disable autosaving for this test
|
|
531
|
+
mock_config.shared.return_value.autosave_runs = False
|
|
532
|
+
mock_config.shared.return_value.user_id = "test_user_id"
|
|
533
|
+
|
|
534
|
+
# Execute
|
|
535
|
+
result = await adapter.invoke("test input")
|
|
536
|
+
|
|
537
|
+
# Verify parser was created correctly
|
|
538
|
+
mock_parser_class.assert_called_once_with(structured_output=False)
|
|
539
|
+
|
|
540
|
+
# Verify parsing occurred
|
|
541
|
+
mock_parser.parse_output.assert_called_once()
|
|
542
|
+
parsed_args = mock_parser.parse_output.call_args[1]
|
|
543
|
+
assert isinstance(parsed_args["original_output"], RunOutput)
|
|
544
|
+
assert parsed_args["original_output"].output == "test output"
|
|
545
|
+
|
|
546
|
+
# Verify result contains parsed output
|
|
547
|
+
assert isinstance(result, TaskRun)
|
|
548
|
+
assert result.output.output == "parsed test output"
|
|
549
|
+
assert result.intermediate_outputs == {"key": "value"}
|
|
550
|
+
assert result.input == "test input"
|