kiln-ai 0.8.1__py3-none-any.whl → 0.11.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of kiln-ai might be problematic. Click here for more details.
- kiln_ai/adapters/__init__.py +7 -7
- kiln_ai/adapters/adapter_registry.py +77 -5
- kiln_ai/adapters/data_gen/data_gen_task.py +3 -3
- kiln_ai/adapters/data_gen/test_data_gen_task.py +23 -3
- kiln_ai/adapters/fine_tune/base_finetune.py +5 -1
- kiln_ai/adapters/fine_tune/dataset_formatter.py +310 -65
- kiln_ai/adapters/fine_tune/fireworks_finetune.py +47 -32
- kiln_ai/adapters/fine_tune/openai_finetune.py +12 -11
- kiln_ai/adapters/fine_tune/test_base_finetune.py +19 -0
- kiln_ai/adapters/fine_tune/test_dataset_formatter.py +469 -129
- kiln_ai/adapters/fine_tune/test_fireworks_tinetune.py +113 -21
- kiln_ai/adapters/fine_tune/test_openai_finetune.py +125 -14
- kiln_ai/adapters/ml_model_list.py +323 -94
- kiln_ai/adapters/model_adapters/__init__.py +18 -0
- kiln_ai/adapters/{base_adapter.py → model_adapters/base_adapter.py} +81 -37
- kiln_ai/adapters/{langchain_adapters.py → model_adapters/langchain_adapters.py} +130 -84
- kiln_ai/adapters/model_adapters/openai_compatible_config.py +11 -0
- kiln_ai/adapters/model_adapters/openai_model_adapter.py +246 -0
- kiln_ai/adapters/model_adapters/test_base_adapter.py +190 -0
- kiln_ai/adapters/{test_langchain_adapter.py → model_adapters/test_langchain_adapter.py} +103 -88
- kiln_ai/adapters/model_adapters/test_openai_model_adapter.py +225 -0
- kiln_ai/adapters/{test_saving_adapter_results.py → model_adapters/test_saving_adapter_results.py} +43 -15
- kiln_ai/adapters/{test_structured_output.py → model_adapters/test_structured_output.py} +93 -20
- kiln_ai/adapters/parsers/__init__.py +10 -0
- kiln_ai/adapters/parsers/base_parser.py +12 -0
- kiln_ai/adapters/parsers/json_parser.py +37 -0
- kiln_ai/adapters/parsers/parser_registry.py +19 -0
- kiln_ai/adapters/parsers/r1_parser.py +69 -0
- kiln_ai/adapters/parsers/test_json_parser.py +81 -0
- kiln_ai/adapters/parsers/test_parser_registry.py +32 -0
- kiln_ai/adapters/parsers/test_r1_parser.py +144 -0
- kiln_ai/adapters/prompt_builders.py +126 -20
- kiln_ai/adapters/provider_tools.py +91 -36
- kiln_ai/adapters/repair/repair_task.py +17 -6
- kiln_ai/adapters/repair/test_repair_task.py +4 -4
- kiln_ai/adapters/run_output.py +8 -0
- kiln_ai/adapters/test_adapter_registry.py +177 -0
- kiln_ai/adapters/test_generate_docs.py +69 -0
- kiln_ai/adapters/test_prompt_adaptors.py +8 -4
- kiln_ai/adapters/test_prompt_builders.py +190 -29
- kiln_ai/adapters/test_provider_tools.py +268 -46
- kiln_ai/datamodel/__init__.py +193 -12
- kiln_ai/datamodel/basemodel.py +31 -11
- kiln_ai/datamodel/json_schema.py +8 -3
- kiln_ai/datamodel/model_cache.py +8 -3
- kiln_ai/datamodel/test_basemodel.py +81 -2
- kiln_ai/datamodel/test_dataset_split.py +100 -3
- kiln_ai/datamodel/test_example_models.py +25 -4
- kiln_ai/datamodel/test_model_cache.py +24 -0
- kiln_ai/datamodel/test_model_perf.py +125 -0
- kiln_ai/datamodel/test_models.py +129 -0
- kiln_ai/utils/exhaustive_error.py +6 -0
- {kiln_ai-0.8.1.dist-info → kiln_ai-0.11.1.dist-info}/METADATA +9 -7
- kiln_ai-0.11.1.dist-info/RECORD +76 -0
- kiln_ai-0.8.1.dist-info/RECORD +0 -58
- {kiln_ai-0.8.1.dist-info → kiln_ai-0.11.1.dist-info}/WHEEL +0 -0
- {kiln_ai-0.8.1.dist-info → kiln_ai-0.11.1.dist-info}/licenses/LICENSE.txt +0 -0
kiln_ai/datamodel/__init__.py
CHANGED
|
@@ -60,6 +60,7 @@ __all__ = [
|
|
|
60
60
|
"TaskRequirement",
|
|
61
61
|
"strict_mode",
|
|
62
62
|
"set_strict_mode",
|
|
63
|
+
"Prompt",
|
|
63
64
|
]
|
|
64
65
|
|
|
65
66
|
|
|
@@ -274,12 +275,47 @@ class FineTuneStatusType(str, Enum):
|
|
|
274
275
|
failed = "failed"
|
|
275
276
|
|
|
276
277
|
|
|
278
|
+
class StructuredOutputMode(str, Enum):
|
|
279
|
+
"""
|
|
280
|
+
Enumeration of supported structured output modes.
|
|
281
|
+
|
|
282
|
+
- default: let the adapter decide
|
|
283
|
+
- json_schema: request json using API capabilities for json_schema
|
|
284
|
+
- function_calling: request json using API capabilities for function calling
|
|
285
|
+
- json_mode: request json using API's JSON mode, which should return valid JSON, but isn't checking/passing the schema
|
|
286
|
+
- json_instructions: append instructions to the prompt to request json matching the schema. No API capabilities are used. You should have a custom parser on these models as they will be returning strings.
|
|
287
|
+
- json_instruction_and_object: append instructions to the prompt to request json matching the schema. Also request the response as json_mode via API capabilities (returning dictionaries).
|
|
288
|
+
"""
|
|
289
|
+
|
|
290
|
+
default = "default"
|
|
291
|
+
json_schema = "json_schema"
|
|
292
|
+
function_calling = "function_calling"
|
|
293
|
+
json_mode = "json_mode"
|
|
294
|
+
json_instructions = "json_instructions"
|
|
295
|
+
json_instruction_and_object = "json_instruction_and_object"
|
|
296
|
+
|
|
297
|
+
|
|
298
|
+
class FinetuneDataStrategy(str, Enum):
|
|
299
|
+
final_only = "final_only"
|
|
300
|
+
final_and_intermediate = "final_and_intermediate"
|
|
301
|
+
|
|
302
|
+
|
|
277
303
|
class Finetune(KilnParentedModel):
|
|
304
|
+
"""
|
|
305
|
+
The Kiln fine-tune datamodel.
|
|
306
|
+
|
|
307
|
+
Initially holds a reference to a training job, with needed identifiers to update the status. When complete, contains the new model ID.
|
|
308
|
+
"""
|
|
309
|
+
|
|
278
310
|
name: str = NAME_FIELD
|
|
279
311
|
description: str | None = Field(
|
|
280
312
|
default=None,
|
|
281
313
|
description="A description of the fine-tune for you and your team. Not used in training.",
|
|
282
314
|
)
|
|
315
|
+
structured_output_mode: StructuredOutputMode | None = Field(
|
|
316
|
+
default=None,
|
|
317
|
+
description="The mode to use to train the model for structured output, if it was trained with structured output. Will determine how we call the tuned model, so we call with the matching mode.",
|
|
318
|
+
)
|
|
283
319
|
provider: str = Field(
|
|
284
320
|
description="The provider to use for the fine-tune (e.g. 'openai')."
|
|
285
321
|
)
|
|
@@ -309,9 +345,14 @@ class Finetune(KilnParentedModel):
|
|
|
309
345
|
default={},
|
|
310
346
|
description="The parameters to use for this fine-tune. These are provider-specific.",
|
|
311
347
|
)
|
|
348
|
+
# These two fields are saved exactly as used for training. Even if they map exactly to a custom prompt or generator, those can change, so we want to keep a record of the training prompt.
|
|
312
349
|
system_message: str = Field(
|
|
313
350
|
description="The system message to use for this fine-tune.",
|
|
314
351
|
)
|
|
352
|
+
thinking_instructions: str | None = Field(
|
|
353
|
+
default=None,
|
|
354
|
+
description="The thinking instructions to use for this fine-tune. Only used when data_strategy is final_and_intermediate.",
|
|
355
|
+
)
|
|
315
356
|
latest_status: FineTuneStatusType = Field(
|
|
316
357
|
default=FineTuneStatusType.unknown,
|
|
317
358
|
description="The latest known status of this fine-tune. Not updated in real time.",
|
|
@@ -320,12 +361,34 @@ class Finetune(KilnParentedModel):
|
|
|
320
361
|
default={},
|
|
321
362
|
description="Properties of the fine-tune. Different providers may use different properties.",
|
|
322
363
|
)
|
|
364
|
+
data_strategy: FinetuneDataStrategy = Field(
|
|
365
|
+
default=FinetuneDataStrategy.final_only,
|
|
366
|
+
description="The strategy to use for training the model. 'final_only' will only train on the final response. 'final_and_intermediate' will train on the final response and intermediate outputs (chain of thought or reasoning).",
|
|
367
|
+
)
|
|
323
368
|
|
|
324
369
|
def parent_task(self) -> Task | None:
|
|
325
370
|
if not isinstance(self.parent, Task):
|
|
326
371
|
return None
|
|
327
372
|
return self.parent
|
|
328
373
|
|
|
374
|
+
@model_validator(mode="after")
|
|
375
|
+
def validate_thinking_instructions(self) -> Self:
|
|
376
|
+
if (
|
|
377
|
+
self.thinking_instructions is not None
|
|
378
|
+
and self.data_strategy != FinetuneDataStrategy.final_and_intermediate
|
|
379
|
+
):
|
|
380
|
+
raise ValueError(
|
|
381
|
+
"Thinking instructions can only be used when data_strategy is final_and_intermediate"
|
|
382
|
+
)
|
|
383
|
+
if (
|
|
384
|
+
self.thinking_instructions is None
|
|
385
|
+
and self.data_strategy == FinetuneDataStrategy.final_and_intermediate
|
|
386
|
+
):
|
|
387
|
+
raise ValueError(
|
|
388
|
+
"Thinking instructions are required when data_strategy is final_and_intermediate"
|
|
389
|
+
)
|
|
390
|
+
return self
|
|
391
|
+
|
|
329
392
|
|
|
330
393
|
class DataSourceType(str, Enum):
|
|
331
394
|
"""
|
|
@@ -397,6 +460,13 @@ class DataSource(BaseModel):
|
|
|
397
460
|
type=str,
|
|
398
461
|
not_allowed_for=[DataSourceType.human],
|
|
399
462
|
),
|
|
463
|
+
DataSourceProperty(
|
|
464
|
+
# Optional: an ID within the scope of the prompt_builder_name.
|
|
465
|
+
# Used for prompt builders with IDs (like saved prompts, fine-tune prompts)
|
|
466
|
+
name="prompt_id",
|
|
467
|
+
type=str,
|
|
468
|
+
not_allowed_for=[DataSourceType.human],
|
|
469
|
+
),
|
|
400
470
|
]
|
|
401
471
|
|
|
402
472
|
@model_validator(mode="after")
|
|
@@ -470,13 +540,39 @@ class TaskRun(KilnParentedModel):
|
|
|
470
540
|
description="Tags for the task run. Tags are used to categorize task runs for filtering and reporting.",
|
|
471
541
|
)
|
|
472
542
|
|
|
543
|
+
def has_thinking_training_data(self) -> bool:
|
|
544
|
+
"""
|
|
545
|
+
Does this run have thinking data that we can use to train a thinking model?
|
|
546
|
+
"""
|
|
547
|
+
if self.intermediate_outputs is None:
|
|
548
|
+
return False
|
|
549
|
+
return (
|
|
550
|
+
"chain_of_thought" in self.intermediate_outputs
|
|
551
|
+
or "reasoning" in self.intermediate_outputs
|
|
552
|
+
)
|
|
553
|
+
|
|
473
554
|
def parent_task(self) -> Task | None:
|
|
474
555
|
if not isinstance(self.parent, Task):
|
|
475
556
|
return None
|
|
476
557
|
return self.parent
|
|
477
558
|
|
|
478
559
|
@model_validator(mode="after")
|
|
479
|
-
def validate_input_format(self) -> Self:
|
|
560
|
+
def validate_input_format(self, info: ValidationInfo) -> Self:
|
|
561
|
+
# Don't validate if loading from file (not new). Too slow.
|
|
562
|
+
# We don't allow changing task schema, so this is redundant validation.
|
|
563
|
+
# Note: we still validate if editing a loaded model
|
|
564
|
+
if self.loading_from_file(info):
|
|
565
|
+
# Consider loading an existing model as validated.
|
|
566
|
+
self._last_validated_input = self.input
|
|
567
|
+
return self
|
|
568
|
+
|
|
569
|
+
# Don't validate if input has not changed. Too slow to run this every time.
|
|
570
|
+
if (
|
|
571
|
+
hasattr(self, "_last_validated_input")
|
|
572
|
+
and self.input == self._last_validated_input
|
|
573
|
+
):
|
|
574
|
+
return self
|
|
575
|
+
|
|
480
576
|
task = self.parent_task()
|
|
481
577
|
if task is None:
|
|
482
578
|
# don't validate this relationship until we have a path or parent. Give them time to build it (but will catch it before saving)
|
|
@@ -490,15 +586,33 @@ class TaskRun(KilnParentedModel):
|
|
|
490
586
|
raise ValueError("Input is not a valid JSON object")
|
|
491
587
|
except jsonschema.exceptions.ValidationError as e:
|
|
492
588
|
raise ValueError(f"Input does not match task input schema: {e}")
|
|
589
|
+
self._last_validated_input = self.input
|
|
493
590
|
return self
|
|
494
591
|
|
|
495
592
|
@model_validator(mode="after")
|
|
496
|
-
def validate_output_format(self) -> Self:
|
|
593
|
+
def validate_output_format(self, info: ValidationInfo) -> Self:
|
|
594
|
+
# Don't validate if loading from file (not new). Too slow.
|
|
595
|
+
# Note: we still validate if editing a loaded model's output.
|
|
596
|
+
if self.loading_from_file(info):
|
|
597
|
+
# Consider loading an existing model as validated.
|
|
598
|
+
self._last_validated_output = self.output.output if self.output else None
|
|
599
|
+
return self
|
|
600
|
+
|
|
601
|
+
# Don't validate unless output has changed since last validation.
|
|
602
|
+
# The validator is slow and costly, don't want it running when setting other fields.
|
|
603
|
+
if (
|
|
604
|
+
hasattr(self, "_last_validated_output")
|
|
605
|
+
and self.output is not None
|
|
606
|
+
and self.output.output == self._last_validated_output
|
|
607
|
+
):
|
|
608
|
+
return self
|
|
609
|
+
|
|
497
610
|
task = self.parent_task()
|
|
498
611
|
if task is None:
|
|
499
612
|
return self
|
|
500
613
|
|
|
501
614
|
self.output.validate_output_format(task)
|
|
615
|
+
self._last_validated_output = self.output.output if self.output else None
|
|
502
616
|
return self
|
|
503
617
|
|
|
504
618
|
@model_validator(mode="after")
|
|
@@ -550,11 +664,47 @@ def AllDatasetFilter(_: TaskRun) -> bool:
|
|
|
550
664
|
|
|
551
665
|
|
|
552
666
|
def HighRatingDatasetFilter(task_run: TaskRun) -> bool:
|
|
553
|
-
if task_run.output is None
|
|
667
|
+
if task_run.output is None:
|
|
668
|
+
return False
|
|
669
|
+
if task_run.repaired_output is not None:
|
|
670
|
+
# Repairs always considered high quality
|
|
671
|
+
return True
|
|
672
|
+
if task_run.output.rating is None:
|
|
554
673
|
return False
|
|
555
674
|
return task_run.output.rating.is_high_quality()
|
|
556
675
|
|
|
557
676
|
|
|
677
|
+
def ThinkingModelDatasetFilter(task_run: TaskRun) -> bool:
|
|
678
|
+
"""
|
|
679
|
+
A filter that returns True if the task has intermediate outputs we can train a 'thinking' model on (reasoning or chain of thought)
|
|
680
|
+
"""
|
|
681
|
+
return task_run.has_thinking_training_data()
|
|
682
|
+
|
|
683
|
+
|
|
684
|
+
def ThinkingModelHighRatedFilter(task_run: TaskRun) -> bool:
|
|
685
|
+
"""
|
|
686
|
+
A filter that returns True if the task has thinking data and the output is high quality
|
|
687
|
+
"""
|
|
688
|
+
return ThinkingModelDatasetFilter(task_run) and HighRatingDatasetFilter(task_run)
|
|
689
|
+
|
|
690
|
+
|
|
691
|
+
class DatasetFilterType(str, Enum):
|
|
692
|
+
"""Dataset filter names."""
|
|
693
|
+
|
|
694
|
+
ALL = "all"
|
|
695
|
+
HIGH_RATING = "high_rating"
|
|
696
|
+
THINKING_MODEL = "thinking_model"
|
|
697
|
+
THINKING_MODEL_HIGH_RATED = "thinking_model_high_rated"
|
|
698
|
+
|
|
699
|
+
|
|
700
|
+
dataset_filters = {
|
|
701
|
+
DatasetFilterType.ALL: AllDatasetFilter,
|
|
702
|
+
DatasetFilterType.HIGH_RATING: HighRatingDatasetFilter,
|
|
703
|
+
DatasetFilterType.THINKING_MODEL: ThinkingModelDatasetFilter,
|
|
704
|
+
DatasetFilterType.THINKING_MODEL_HIGH_RATED: ThinkingModelHighRatedFilter,
|
|
705
|
+
}
|
|
706
|
+
|
|
707
|
+
|
|
558
708
|
class DatasetSplitDefinition(BaseModel):
|
|
559
709
|
"""
|
|
560
710
|
A definition of a split in a dataset.
|
|
@@ -586,6 +736,11 @@ Train60Test20Val20SplitDefinition: list[DatasetSplitDefinition] = [
|
|
|
586
736
|
DatasetSplitDefinition(name="test", percentage=0.2),
|
|
587
737
|
DatasetSplitDefinition(name="val", percentage=0.2),
|
|
588
738
|
]
|
|
739
|
+
Train80Test10Val10SplitDefinition: list[DatasetSplitDefinition] = [
|
|
740
|
+
DatasetSplitDefinition(name="train", percentage=0.8),
|
|
741
|
+
DatasetSplitDefinition(name="test", percentage=0.1),
|
|
742
|
+
DatasetSplitDefinition(name="val", percentage=0.1),
|
|
743
|
+
]
|
|
589
744
|
|
|
590
745
|
|
|
591
746
|
class DatasetSplit(KilnParentedModel):
|
|
@@ -609,6 +764,10 @@ class DatasetSplit(KilnParentedModel):
|
|
|
609
764
|
split_contents: dict[str, list[str]] = Field(
|
|
610
765
|
description="The contents of each split in the dataset. The key is the split name, and the value is a list of task run IDs.",
|
|
611
766
|
)
|
|
767
|
+
filter: DatasetFilterType | None = Field(
|
|
768
|
+
default=None,
|
|
769
|
+
description="The filter used to build the dataset.",
|
|
770
|
+
)
|
|
612
771
|
|
|
613
772
|
@model_validator(mode="after")
|
|
614
773
|
def validate_split_percentages(self) -> "DatasetSplit":
|
|
@@ -623,12 +782,13 @@ class DatasetSplit(KilnParentedModel):
|
|
|
623
782
|
name: str,
|
|
624
783
|
task: "Task",
|
|
625
784
|
splits: list[DatasetSplitDefinition],
|
|
626
|
-
|
|
785
|
+
filter_type: DatasetFilterType = DatasetFilterType.ALL,
|
|
627
786
|
description: str | None = None,
|
|
628
787
|
):
|
|
629
788
|
"""
|
|
630
789
|
Build a dataset split from a task.
|
|
631
790
|
"""
|
|
791
|
+
filter = dataset_filters[filter_type]
|
|
632
792
|
split_contents = cls.build_split_contents(task, splits, filter)
|
|
633
793
|
return cls(
|
|
634
794
|
parent=task,
|
|
@@ -636,6 +796,7 @@ class DatasetSplit(KilnParentedModel):
|
|
|
636
796
|
description=description,
|
|
637
797
|
splits=splits,
|
|
638
798
|
split_contents=split_contents,
|
|
799
|
+
filter=filter_type,
|
|
639
800
|
)
|
|
640
801
|
|
|
641
802
|
@classmethod
|
|
@@ -686,7 +847,7 @@ class DatasetSplit(KilnParentedModel):
|
|
|
686
847
|
if parent is None:
|
|
687
848
|
raise ValueError("DatasetSplit has no parent task")
|
|
688
849
|
|
|
689
|
-
runs = parent.runs()
|
|
850
|
+
runs = parent.runs(readonly=True)
|
|
690
851
|
all_ids = set(run.id for run in runs)
|
|
691
852
|
all_ids_in_splits = set()
|
|
692
853
|
for ids in self.split_contents.values():
|
|
@@ -695,6 +856,22 @@ class DatasetSplit(KilnParentedModel):
|
|
|
695
856
|
return len(missing)
|
|
696
857
|
|
|
697
858
|
|
|
859
|
+
class Prompt(KilnParentedModel):
|
|
860
|
+
"""
|
|
861
|
+
A prompt for a task.
|
|
862
|
+
"""
|
|
863
|
+
|
|
864
|
+
name: str = NAME_FIELD
|
|
865
|
+
prompt: str = Field(
|
|
866
|
+
description="The prompt for the task.",
|
|
867
|
+
min_length=1,
|
|
868
|
+
)
|
|
869
|
+
chain_of_thought_instructions: str | None = Field(
|
|
870
|
+
default=None,
|
|
871
|
+
description="Instructions for the model 'thinking' about the requirement prior to answering. Used for chain of thought style prompting. COT will not be used unless this is provided.",
|
|
872
|
+
)
|
|
873
|
+
|
|
874
|
+
|
|
698
875
|
class TaskRequirement(BaseModel):
|
|
699
876
|
"""
|
|
700
877
|
Defines a specific requirement that should be met by task outputs.
|
|
@@ -732,6 +909,7 @@ class Task(
|
|
|
732
909
|
"runs": TaskRun,
|
|
733
910
|
"dataset_splits": DatasetSplit,
|
|
734
911
|
"finetunes": Finetune,
|
|
912
|
+
"prompts": Prompt,
|
|
735
913
|
},
|
|
736
914
|
):
|
|
737
915
|
"""
|
|
@@ -768,15 +946,18 @@ class Task(
|
|
|
768
946
|
return None
|
|
769
947
|
return schema_from_json_str(self.input_json_schema)
|
|
770
948
|
|
|
771
|
-
#
|
|
772
|
-
def runs(self) -> list[TaskRun]:
|
|
773
|
-
return super().runs() # type: ignore
|
|
949
|
+
# These wrappers help for typechecking. TODO P2: fix this in KilnParentModel
|
|
950
|
+
def runs(self, readonly: bool = False) -> list[TaskRun]:
|
|
951
|
+
return super().runs(readonly=readonly) # type: ignore
|
|
952
|
+
|
|
953
|
+
def dataset_splits(self, readonly: bool = False) -> list[DatasetSplit]:
|
|
954
|
+
return super().dataset_splits(readonly=readonly) # type: ignore
|
|
774
955
|
|
|
775
|
-
def
|
|
776
|
-
return super().
|
|
956
|
+
def finetunes(self, readonly: bool = False) -> list[Finetune]:
|
|
957
|
+
return super().finetunes(readonly=readonly) # type: ignore
|
|
777
958
|
|
|
778
|
-
def
|
|
779
|
-
return super().
|
|
959
|
+
def prompts(self, readonly: bool = False) -> list[Prompt]:
|
|
960
|
+
return super().prompts(readonly=readonly) # type: ignore
|
|
780
961
|
|
|
781
962
|
|
|
782
963
|
class Project(KilnParentModel, parent_of={"tasks": Task}):
|
kiln_ai/datamodel/basemodel.py
CHANGED
|
@@ -120,11 +120,12 @@ class KilnBaseModel(BaseModel):
|
|
|
120
120
|
return cls.load_from_file(path)
|
|
121
121
|
|
|
122
122
|
@classmethod
|
|
123
|
-
def load_from_file(cls: Type[T], path: Path | str) -> T:
|
|
123
|
+
def load_from_file(cls: Type[T], path: Path | str, readonly: bool = False) -> T:
|
|
124
124
|
"""Load a model instance from a specific file path.
|
|
125
125
|
|
|
126
126
|
Args:
|
|
127
127
|
path (Path): Path to the model file
|
|
128
|
+
readonly (bool): If True, the model will be returned in readonly mode (cached instance, not a copy, not safe to mutate)
|
|
128
129
|
|
|
129
130
|
Returns:
|
|
130
131
|
T: Instance of the model
|
|
@@ -135,10 +136,10 @@ class KilnBaseModel(BaseModel):
|
|
|
135
136
|
"""
|
|
136
137
|
if isinstance(path, str):
|
|
137
138
|
path = Path(path)
|
|
138
|
-
cached_model = ModelCache.shared().get_model(path, cls)
|
|
139
|
+
cached_model = ModelCache.shared().get_model(path, cls, readonly=readonly)
|
|
139
140
|
if cached_model is not None:
|
|
140
141
|
return cached_model
|
|
141
|
-
with open(path, "r") as file:
|
|
142
|
+
with open(path, "r", encoding="utf-8") as file:
|
|
142
143
|
# modified time of file for cache invalidation. From file descriptor so it's atomic w read.
|
|
143
144
|
mtime_ns = os.fstat(file.fileno()).st_mtime_ns
|
|
144
145
|
file_data = file.read()
|
|
@@ -168,13 +169,20 @@ class KilnBaseModel(BaseModel):
|
|
|
168
169
|
# Two ways of indicating it's loaded from file:
|
|
169
170
|
# 1) info.context.get("loading_from_file") -> During actual loading, before we can set _loaded_from_file
|
|
170
171
|
# 2) self._loaded_from_file -> After loading, set by the loader
|
|
172
|
+
if self.loading_from_file(info):
|
|
173
|
+
return True
|
|
174
|
+
return self._loaded_from_file
|
|
175
|
+
|
|
176
|
+
# indicates the model is currently being loaded from file (not mutating it after)
|
|
177
|
+
def loading_from_file(self, info: ValidationInfo | None = None) -> bool:
|
|
178
|
+
# info.context.get("loading_from_file") -> During actual loading, before we can set _loaded_from_file
|
|
171
179
|
if (
|
|
172
180
|
info is not None
|
|
173
181
|
and info.context is not None
|
|
174
182
|
and info.context.get("loading_from_file", False)
|
|
175
183
|
):
|
|
176
184
|
return True
|
|
177
|
-
return
|
|
185
|
+
return False
|
|
178
186
|
|
|
179
187
|
def save_to_file(self) -> None:
|
|
180
188
|
"""Save the model instance to a file.
|
|
@@ -190,7 +198,7 @@ class KilnBaseModel(BaseModel):
|
|
|
190
198
|
)
|
|
191
199
|
path.parent.mkdir(parents=True, exist_ok=True)
|
|
192
200
|
json_data = self.model_dump_json(indent=2, exclude={"path"})
|
|
193
|
-
with open(path, "w") as file:
|
|
201
|
+
with open(path, "w", encoding="utf-8") as file:
|
|
194
202
|
file.write(json_data)
|
|
195
203
|
# save the path so even if something like name changes, the file doesn't move
|
|
196
204
|
self.path = path
|
|
@@ -342,16 +350,28 @@ class KilnParentedModel(KilnBaseModel, metaclass=ABCMeta):
|
|
|
342
350
|
return []
|
|
343
351
|
|
|
344
352
|
# Collect all /relationship/{id}/{base_filename.kiln} files in the relationship folder
|
|
345
|
-
for
|
|
346
|
-
|
|
353
|
+
# manual code instead of glob for performance (5x speedup over glob)
|
|
354
|
+
|
|
355
|
+
base_filename = cls.base_filename()
|
|
356
|
+
# Iterate through immediate subdirectories using scandir for better performance
|
|
357
|
+
# Benchmark: scandir is 10x faster than glob, so worth the extra code
|
|
358
|
+
with os.scandir(relationship_folder) as entries:
|
|
359
|
+
for entry in entries:
|
|
360
|
+
if not entry.is_dir():
|
|
361
|
+
continue
|
|
362
|
+
|
|
363
|
+
child_file = Path(entry.path) / base_filename
|
|
364
|
+
if child_file.is_file():
|
|
365
|
+
yield child_file
|
|
347
366
|
|
|
348
367
|
@classmethod
|
|
349
368
|
def all_children_of_parent_path(
|
|
350
|
-
cls: Type[PT], parent_path: Path | None
|
|
369
|
+
cls: Type[PT], parent_path: Path | None, readonly: bool = False
|
|
351
370
|
) -> list[PT]:
|
|
352
371
|
children = []
|
|
353
372
|
for child_path in cls.iterate_children_paths_of_parent_path(parent_path):
|
|
354
|
-
|
|
373
|
+
item = cls.load_from_file(child_path, readonly=readonly)
|
|
374
|
+
children.append(item)
|
|
355
375
|
return children
|
|
356
376
|
|
|
357
377
|
@classmethod
|
|
@@ -394,8 +414,8 @@ class KilnParentModel(KilnBaseModel, metaclass=ABCMeta):
|
|
|
394
414
|
def _create_child_method(
|
|
395
415
|
cls, relationship_name: str, child_class: Type[KilnParentedModel]
|
|
396
416
|
):
|
|
397
|
-
def child_method(self) -> list[child_class]:
|
|
398
|
-
return child_class.all_children_of_parent_path(self.path)
|
|
417
|
+
def child_method(self, readonly: bool = False) -> list[child_class]:
|
|
418
|
+
return child_class.all_children_of_parent_path(self.path, readonly=readonly)
|
|
399
419
|
|
|
400
420
|
child_method.__name__ = relationship_name
|
|
401
421
|
child_method.__annotations__ = {"return": List[child_class]}
|
kiln_ai/datamodel/json_schema.py
CHANGED
|
@@ -42,9 +42,14 @@ def validate_schema(instance: Dict, schema_str: str) -> None:
|
|
|
42
42
|
jsonschema.exceptions.ValidationError: If validation fails
|
|
43
43
|
ValueError: If the schema is invalid
|
|
44
44
|
"""
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
45
|
+
try:
|
|
46
|
+
schema = schema_from_json_str(schema_str)
|
|
47
|
+
v = jsonschema.Draft202012Validator(schema)
|
|
48
|
+
v.validate(instance)
|
|
49
|
+
except jsonschema.exceptions.ValidationError as e:
|
|
50
|
+
raise ValueError(
|
|
51
|
+
f"This task requires a specific output schema. While the model produced JSON, that JSON didn't meet the schema. Search 'Troubleshooting Structured Data Issues' in our docs for more information. The error from the schema check was: {e.message}"
|
|
52
|
+
) from e
|
|
48
53
|
|
|
49
54
|
|
|
50
55
|
def schema_from_json_str(v: str) -> Dict:
|
kiln_ai/datamodel/model_cache.py
CHANGED
|
@@ -62,12 +62,17 @@ class ModelCache:
|
|
|
62
62
|
raise ValueError(f"Model at {path} is not of type {model_type.__name__}")
|
|
63
63
|
return model
|
|
64
64
|
|
|
65
|
-
def get_model(
|
|
66
|
-
|
|
65
|
+
def get_model(
|
|
66
|
+
self, path: Path, model_type: Type[T], readonly: bool = False
|
|
67
|
+
) -> Optional[T]:
|
|
68
|
+
# We return a copy by default, so in-memory edits don't impact the cache until they are saved
|
|
67
69
|
# Benchmark shows about 2x slower, but much more foolproof
|
|
68
70
|
model = self._get_model(path, model_type)
|
|
69
71
|
if model:
|
|
70
|
-
|
|
72
|
+
if readonly:
|
|
73
|
+
return model
|
|
74
|
+
else:
|
|
75
|
+
return model.model_copy(deep=True)
|
|
71
76
|
return None
|
|
72
77
|
|
|
73
78
|
def get_model_id(self, path: Path, model_type: Type[T]) -> Optional[str]:
|
|
@@ -6,6 +6,9 @@ from unittest.mock import MagicMock, patch
|
|
|
6
6
|
|
|
7
7
|
import pytest
|
|
8
8
|
|
|
9
|
+
from kiln_ai.adapters.model_adapters.base_adapter import AdapterInfo, BaseAdapter
|
|
10
|
+
from kiln_ai.adapters.run_output import RunOutput
|
|
11
|
+
from kiln_ai.datamodel import Task, TaskRun
|
|
9
12
|
from kiln_ai.datamodel.basemodel import (
|
|
10
13
|
KilnBaseModel,
|
|
11
14
|
KilnParentedModel,
|
|
@@ -356,7 +359,9 @@ def test_load_from_file_with_cache(test_base_file, tmp_model_cache):
|
|
|
356
359
|
model = KilnBaseModel.load_from_file(test_base_file)
|
|
357
360
|
|
|
358
361
|
# Check that the cache was checked and set
|
|
359
|
-
tmp_model_cache.get_model.assert_called_once_with(
|
|
362
|
+
tmp_model_cache.get_model.assert_called_once_with(
|
|
363
|
+
test_base_file, KilnBaseModel, readonly=False
|
|
364
|
+
)
|
|
360
365
|
tmp_model_cache.set_model.assert_called_once()
|
|
361
366
|
|
|
362
367
|
# Ensure the model is correctly loaded
|
|
@@ -407,7 +412,9 @@ def test_load_from_file_with_cached_model(test_base_file, tmp_model_cache):
|
|
|
407
412
|
model = KilnBaseModel.load_from_file(test_base_file)
|
|
408
413
|
|
|
409
414
|
# Check that the cache was checked and the cached model was returned
|
|
410
|
-
tmp_model_cache.get_model.assert_called_once_with(
|
|
415
|
+
tmp_model_cache.get_model.assert_called_once_with(
|
|
416
|
+
test_base_file, KilnBaseModel, readonly=False
|
|
417
|
+
)
|
|
411
418
|
assert model is cached_model
|
|
412
419
|
|
|
413
420
|
# Assert that open was not called (we used the cached model, not file)
|
|
@@ -469,3 +476,75 @@ def test_from_id_and_parent_path_without_parent():
|
|
|
469
476
|
# Test with None parent_path
|
|
470
477
|
not_found = DefaultParentedModel.from_id_and_parent_path("any-id", None)
|
|
471
478
|
assert not_found is None
|
|
479
|
+
|
|
480
|
+
|
|
481
|
+
class MockAdapter(BaseAdapter):
|
|
482
|
+
"""Implementation of BaseAdapter for testing"""
|
|
483
|
+
|
|
484
|
+
async def _run(self, input):
|
|
485
|
+
return RunOutput(output="test output", intermediate_outputs=None)
|
|
486
|
+
|
|
487
|
+
def adapter_info(self) -> AdapterInfo:
|
|
488
|
+
return AdapterInfo(
|
|
489
|
+
adapter_name="test",
|
|
490
|
+
model_name=self.model_name,
|
|
491
|
+
model_provider=self.model_provider_name,
|
|
492
|
+
prompt_builder_name="test",
|
|
493
|
+
)
|
|
494
|
+
|
|
495
|
+
|
|
496
|
+
@pytest.fixture
|
|
497
|
+
def base_task():
|
|
498
|
+
return Task(name="test_task", instruction="test_instruction")
|
|
499
|
+
|
|
500
|
+
|
|
501
|
+
@pytest.fixture
|
|
502
|
+
def adapter(base_task):
|
|
503
|
+
return MockAdapter(
|
|
504
|
+
kiln_task=base_task,
|
|
505
|
+
model_name="test_model",
|
|
506
|
+
model_provider_name="test_provider",
|
|
507
|
+
)
|
|
508
|
+
|
|
509
|
+
|
|
510
|
+
async def test_invoke_parsing_flow(adapter):
|
|
511
|
+
# Mock dependencies
|
|
512
|
+
mock_provider = MagicMock()
|
|
513
|
+
mock_provider.parser = "test_parser"
|
|
514
|
+
|
|
515
|
+
mock_parser = MagicMock()
|
|
516
|
+
mock_parser.parse_output.return_value = RunOutput(
|
|
517
|
+
output="parsed test output", intermediate_outputs={"key": "value"}
|
|
518
|
+
)
|
|
519
|
+
|
|
520
|
+
mock_parser_class = MagicMock(return_value=mock_parser)
|
|
521
|
+
|
|
522
|
+
with (
|
|
523
|
+
patch.object(adapter, "model_provider", return_value=mock_provider),
|
|
524
|
+
patch(
|
|
525
|
+
"kiln_ai.adapters.model_adapters.base_adapter.model_parser_from_id",
|
|
526
|
+
return_value=mock_parser_class,
|
|
527
|
+
),
|
|
528
|
+
patch("kiln_ai.adapters.model_adapters.base_adapter.Config") as mock_config,
|
|
529
|
+
):
|
|
530
|
+
# Disable autosaving for this test
|
|
531
|
+
mock_config.shared.return_value.autosave_runs = False
|
|
532
|
+
mock_config.shared.return_value.user_id = "test_user_id"
|
|
533
|
+
|
|
534
|
+
# Execute
|
|
535
|
+
result = await adapter.invoke("test input")
|
|
536
|
+
|
|
537
|
+
# Verify parser was created correctly
|
|
538
|
+
mock_parser_class.assert_called_once_with(structured_output=False)
|
|
539
|
+
|
|
540
|
+
# Verify parsing occurred
|
|
541
|
+
mock_parser.parse_output.assert_called_once()
|
|
542
|
+
parsed_args = mock_parser.parse_output.call_args[1]
|
|
543
|
+
assert isinstance(parsed_args["original_output"], RunOutput)
|
|
544
|
+
assert parsed_args["original_output"].output == "test output"
|
|
545
|
+
|
|
546
|
+
# Verify result contains parsed output
|
|
547
|
+
assert isinstance(result, TaskRun)
|
|
548
|
+
assert result.output.output == "parsed test output"
|
|
549
|
+
assert result.intermediate_outputs == {"key": "value"}
|
|
550
|
+
assert result.input == "test input"
|