kiln-ai 0.8.1__py3-none-any.whl → 0.11.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of kiln-ai might be problematic. Click here for more details.

Files changed (57) hide show
  1. kiln_ai/adapters/__init__.py +7 -7
  2. kiln_ai/adapters/adapter_registry.py +77 -5
  3. kiln_ai/adapters/data_gen/data_gen_task.py +3 -3
  4. kiln_ai/adapters/data_gen/test_data_gen_task.py +23 -3
  5. kiln_ai/adapters/fine_tune/base_finetune.py +5 -1
  6. kiln_ai/adapters/fine_tune/dataset_formatter.py +310 -65
  7. kiln_ai/adapters/fine_tune/fireworks_finetune.py +47 -32
  8. kiln_ai/adapters/fine_tune/openai_finetune.py +12 -11
  9. kiln_ai/adapters/fine_tune/test_base_finetune.py +19 -0
  10. kiln_ai/adapters/fine_tune/test_dataset_formatter.py +469 -129
  11. kiln_ai/adapters/fine_tune/test_fireworks_tinetune.py +113 -21
  12. kiln_ai/adapters/fine_tune/test_openai_finetune.py +125 -14
  13. kiln_ai/adapters/ml_model_list.py +323 -94
  14. kiln_ai/adapters/model_adapters/__init__.py +18 -0
  15. kiln_ai/adapters/{base_adapter.py → model_adapters/base_adapter.py} +81 -37
  16. kiln_ai/adapters/{langchain_adapters.py → model_adapters/langchain_adapters.py} +130 -84
  17. kiln_ai/adapters/model_adapters/openai_compatible_config.py +11 -0
  18. kiln_ai/adapters/model_adapters/openai_model_adapter.py +246 -0
  19. kiln_ai/adapters/model_adapters/test_base_adapter.py +190 -0
  20. kiln_ai/adapters/{test_langchain_adapter.py → model_adapters/test_langchain_adapter.py} +103 -88
  21. kiln_ai/adapters/model_adapters/test_openai_model_adapter.py +225 -0
  22. kiln_ai/adapters/{test_saving_adapter_results.py → model_adapters/test_saving_adapter_results.py} +43 -15
  23. kiln_ai/adapters/{test_structured_output.py → model_adapters/test_structured_output.py} +93 -20
  24. kiln_ai/adapters/parsers/__init__.py +10 -0
  25. kiln_ai/adapters/parsers/base_parser.py +12 -0
  26. kiln_ai/adapters/parsers/json_parser.py +37 -0
  27. kiln_ai/adapters/parsers/parser_registry.py +19 -0
  28. kiln_ai/adapters/parsers/r1_parser.py +69 -0
  29. kiln_ai/adapters/parsers/test_json_parser.py +81 -0
  30. kiln_ai/adapters/parsers/test_parser_registry.py +32 -0
  31. kiln_ai/adapters/parsers/test_r1_parser.py +144 -0
  32. kiln_ai/adapters/prompt_builders.py +126 -20
  33. kiln_ai/adapters/provider_tools.py +91 -36
  34. kiln_ai/adapters/repair/repair_task.py +17 -6
  35. kiln_ai/adapters/repair/test_repair_task.py +4 -4
  36. kiln_ai/adapters/run_output.py +8 -0
  37. kiln_ai/adapters/test_adapter_registry.py +177 -0
  38. kiln_ai/adapters/test_generate_docs.py +69 -0
  39. kiln_ai/adapters/test_prompt_adaptors.py +8 -4
  40. kiln_ai/adapters/test_prompt_builders.py +190 -29
  41. kiln_ai/adapters/test_provider_tools.py +268 -46
  42. kiln_ai/datamodel/__init__.py +193 -12
  43. kiln_ai/datamodel/basemodel.py +31 -11
  44. kiln_ai/datamodel/json_schema.py +8 -3
  45. kiln_ai/datamodel/model_cache.py +8 -3
  46. kiln_ai/datamodel/test_basemodel.py +81 -2
  47. kiln_ai/datamodel/test_dataset_split.py +100 -3
  48. kiln_ai/datamodel/test_example_models.py +25 -4
  49. kiln_ai/datamodel/test_model_cache.py +24 -0
  50. kiln_ai/datamodel/test_model_perf.py +125 -0
  51. kiln_ai/datamodel/test_models.py +129 -0
  52. kiln_ai/utils/exhaustive_error.py +6 -0
  53. {kiln_ai-0.8.1.dist-info → kiln_ai-0.11.1.dist-info}/METADATA +9 -7
  54. kiln_ai-0.11.1.dist-info/RECORD +76 -0
  55. kiln_ai-0.8.1.dist-info/RECORD +0 -58
  56. {kiln_ai-0.8.1.dist-info → kiln_ai-0.11.1.dist-info}/WHEEL +0 -0
  57. {kiln_ai-0.8.1.dist-info → kiln_ai-0.11.1.dist-info}/licenses/LICENSE.txt +0 -0
@@ -60,6 +60,7 @@ __all__ = [
60
60
  "TaskRequirement",
61
61
  "strict_mode",
62
62
  "set_strict_mode",
63
+ "Prompt",
63
64
  ]
64
65
 
65
66
 
@@ -274,12 +275,47 @@ class FineTuneStatusType(str, Enum):
274
275
  failed = "failed"
275
276
 
276
277
 
278
+ class StructuredOutputMode(str, Enum):
279
+ """
280
+ Enumeration of supported structured output modes.
281
+
282
+ - default: let the adapter decide
283
+ - json_schema: request json using API capabilities for json_schema
284
+ - function_calling: request json using API capabilities for function calling
285
+ - json_mode: request json using API's JSON mode, which should return valid JSON, but isn't checking/passing the schema
286
+ - json_instructions: append instructions to the prompt to request json matching the schema. No API capabilities are used. You should have a custom parser on these models as they will be returning strings.
287
+ - json_instruction_and_object: append instructions to the prompt to request json matching the schema. Also request the response as json_mode via API capabilities (returning dictionaries).
288
+ """
289
+
290
+ default = "default"
291
+ json_schema = "json_schema"
292
+ function_calling = "function_calling"
293
+ json_mode = "json_mode"
294
+ json_instructions = "json_instructions"
295
+ json_instruction_and_object = "json_instruction_and_object"
296
+
297
+
298
+ class FinetuneDataStrategy(str, Enum):
299
+ final_only = "final_only"
300
+ final_and_intermediate = "final_and_intermediate"
301
+
302
+
277
303
  class Finetune(KilnParentedModel):
304
+ """
305
+ The Kiln fine-tune datamodel.
306
+
307
+ Initially holds a reference to a training job, with needed identifiers to update the status. When complete, contains the new model ID.
308
+ """
309
+
278
310
  name: str = NAME_FIELD
279
311
  description: str | None = Field(
280
312
  default=None,
281
313
  description="A description of the fine-tune for you and your team. Not used in training.",
282
314
  )
315
+ structured_output_mode: StructuredOutputMode | None = Field(
316
+ default=None,
317
+ description="The mode to use to train the model for structured output, if it was trained with structured output. Will determine how we call the tuned model, so we call with the matching mode.",
318
+ )
283
319
  provider: str = Field(
284
320
  description="The provider to use for the fine-tune (e.g. 'openai')."
285
321
  )
@@ -309,9 +345,14 @@ class Finetune(KilnParentedModel):
309
345
  default={},
310
346
  description="The parameters to use for this fine-tune. These are provider-specific.",
311
347
  )
348
+ # These two fields are saved exactly as used for training. Even if they map exactly to a custom prompt or generator, those can change, so we want to keep a record of the training prompt.
312
349
  system_message: str = Field(
313
350
  description="The system message to use for this fine-tune.",
314
351
  )
352
+ thinking_instructions: str | None = Field(
353
+ default=None,
354
+ description="The thinking instructions to use for this fine-tune. Only used when data_strategy is final_and_intermediate.",
355
+ )
315
356
  latest_status: FineTuneStatusType = Field(
316
357
  default=FineTuneStatusType.unknown,
317
358
  description="The latest known status of this fine-tune. Not updated in real time.",
@@ -320,12 +361,34 @@ class Finetune(KilnParentedModel):
320
361
  default={},
321
362
  description="Properties of the fine-tune. Different providers may use different properties.",
322
363
  )
364
+ data_strategy: FinetuneDataStrategy = Field(
365
+ default=FinetuneDataStrategy.final_only,
366
+ description="The strategy to use for training the model. 'final_only' will only train on the final response. 'final_and_intermediate' will train on the final response and intermediate outputs (chain of thought or reasoning).",
367
+ )
323
368
 
324
369
  def parent_task(self) -> Task | None:
325
370
  if not isinstance(self.parent, Task):
326
371
  return None
327
372
  return self.parent
328
373
 
374
+ @model_validator(mode="after")
375
+ def validate_thinking_instructions(self) -> Self:
376
+ if (
377
+ self.thinking_instructions is not None
378
+ and self.data_strategy != FinetuneDataStrategy.final_and_intermediate
379
+ ):
380
+ raise ValueError(
381
+ "Thinking instructions can only be used when data_strategy is final_and_intermediate"
382
+ )
383
+ if (
384
+ self.thinking_instructions is None
385
+ and self.data_strategy == FinetuneDataStrategy.final_and_intermediate
386
+ ):
387
+ raise ValueError(
388
+ "Thinking instructions are required when data_strategy is final_and_intermediate"
389
+ )
390
+ return self
391
+
329
392
 
330
393
  class DataSourceType(str, Enum):
331
394
  """
@@ -397,6 +460,13 @@ class DataSource(BaseModel):
397
460
  type=str,
398
461
  not_allowed_for=[DataSourceType.human],
399
462
  ),
463
+ DataSourceProperty(
464
+ # Optional: an ID within the scope of the prompt_builder_name.
465
+ # Used for prompt builders with IDs (like saved prompts, fine-tune prompts)
466
+ name="prompt_id",
467
+ type=str,
468
+ not_allowed_for=[DataSourceType.human],
469
+ ),
400
470
  ]
401
471
 
402
472
  @model_validator(mode="after")
@@ -470,13 +540,39 @@ class TaskRun(KilnParentedModel):
470
540
  description="Tags for the task run. Tags are used to categorize task runs for filtering and reporting.",
471
541
  )
472
542
 
543
+ def has_thinking_training_data(self) -> bool:
544
+ """
545
+ Does this run have thinking data that we can use to train a thinking model?
546
+ """
547
+ if self.intermediate_outputs is None:
548
+ return False
549
+ return (
550
+ "chain_of_thought" in self.intermediate_outputs
551
+ or "reasoning" in self.intermediate_outputs
552
+ )
553
+
473
554
  def parent_task(self) -> Task | None:
474
555
  if not isinstance(self.parent, Task):
475
556
  return None
476
557
  return self.parent
477
558
 
478
559
  @model_validator(mode="after")
479
- def validate_input_format(self) -> Self:
560
+ def validate_input_format(self, info: ValidationInfo) -> Self:
561
+ # Don't validate if loading from file (not new). Too slow.
562
+ # We don't allow changing task schema, so this is redundant validation.
563
+ # Note: we still validate if editing a loaded model
564
+ if self.loading_from_file(info):
565
+ # Consider loading an existing model as validated.
566
+ self._last_validated_input = self.input
567
+ return self
568
+
569
+ # Don't validate if input has not changed. Too slow to run this every time.
570
+ if (
571
+ hasattr(self, "_last_validated_input")
572
+ and self.input == self._last_validated_input
573
+ ):
574
+ return self
575
+
480
576
  task = self.parent_task()
481
577
  if task is None:
482
578
  # don't validate this relationship until we have a path or parent. Give them time to build it (but will catch it before saving)
@@ -490,15 +586,33 @@ class TaskRun(KilnParentedModel):
490
586
  raise ValueError("Input is not a valid JSON object")
491
587
  except jsonschema.exceptions.ValidationError as e:
492
588
  raise ValueError(f"Input does not match task input schema: {e}")
589
+ self._last_validated_input = self.input
493
590
  return self
494
591
 
495
592
  @model_validator(mode="after")
496
- def validate_output_format(self) -> Self:
593
+ def validate_output_format(self, info: ValidationInfo) -> Self:
594
+ # Don't validate if loading from file (not new). Too slow.
595
+ # Note: we still validate if editing a loaded model's output.
596
+ if self.loading_from_file(info):
597
+ # Consider loading an existing model as validated.
598
+ self._last_validated_output = self.output.output if self.output else None
599
+ return self
600
+
601
+ # Don't validate unless output has changed since last validation.
602
+ # The validator is slow and costly, don't want it running when setting other fields.
603
+ if (
604
+ hasattr(self, "_last_validated_output")
605
+ and self.output is not None
606
+ and self.output.output == self._last_validated_output
607
+ ):
608
+ return self
609
+
497
610
  task = self.parent_task()
498
611
  if task is None:
499
612
  return self
500
613
 
501
614
  self.output.validate_output_format(task)
615
+ self._last_validated_output = self.output.output if self.output else None
502
616
  return self
503
617
 
504
618
  @model_validator(mode="after")
@@ -550,11 +664,47 @@ def AllDatasetFilter(_: TaskRun) -> bool:
550
664
 
551
665
 
552
666
  def HighRatingDatasetFilter(task_run: TaskRun) -> bool:
553
- if task_run.output is None or task_run.output.rating is None:
667
+ if task_run.output is None:
668
+ return False
669
+ if task_run.repaired_output is not None:
670
+ # Repairs always considered high quality
671
+ return True
672
+ if task_run.output.rating is None:
554
673
  return False
555
674
  return task_run.output.rating.is_high_quality()
556
675
 
557
676
 
677
+ def ThinkingModelDatasetFilter(task_run: TaskRun) -> bool:
678
+ """
679
+ A filter that returns True if the task has intermediate outputs we can train a 'thinking' model on (reasoning or chain of thought)
680
+ """
681
+ return task_run.has_thinking_training_data()
682
+
683
+
684
+ def ThinkingModelHighRatedFilter(task_run: TaskRun) -> bool:
685
+ """
686
+ A filter that returns True if the task has thinking data and the output is high quality
687
+ """
688
+ return ThinkingModelDatasetFilter(task_run) and HighRatingDatasetFilter(task_run)
689
+
690
+
691
+ class DatasetFilterType(str, Enum):
692
+ """Dataset filter names."""
693
+
694
+ ALL = "all"
695
+ HIGH_RATING = "high_rating"
696
+ THINKING_MODEL = "thinking_model"
697
+ THINKING_MODEL_HIGH_RATED = "thinking_model_high_rated"
698
+
699
+
700
+ dataset_filters = {
701
+ DatasetFilterType.ALL: AllDatasetFilter,
702
+ DatasetFilterType.HIGH_RATING: HighRatingDatasetFilter,
703
+ DatasetFilterType.THINKING_MODEL: ThinkingModelDatasetFilter,
704
+ DatasetFilterType.THINKING_MODEL_HIGH_RATED: ThinkingModelHighRatedFilter,
705
+ }
706
+
707
+
558
708
  class DatasetSplitDefinition(BaseModel):
559
709
  """
560
710
  A definition of a split in a dataset.
@@ -586,6 +736,11 @@ Train60Test20Val20SplitDefinition: list[DatasetSplitDefinition] = [
586
736
  DatasetSplitDefinition(name="test", percentage=0.2),
587
737
  DatasetSplitDefinition(name="val", percentage=0.2),
588
738
  ]
739
+ Train80Test10Val10SplitDefinition: list[DatasetSplitDefinition] = [
740
+ DatasetSplitDefinition(name="train", percentage=0.8),
741
+ DatasetSplitDefinition(name="test", percentage=0.1),
742
+ DatasetSplitDefinition(name="val", percentage=0.1),
743
+ ]
589
744
 
590
745
 
591
746
  class DatasetSplit(KilnParentedModel):
@@ -609,6 +764,10 @@ class DatasetSplit(KilnParentedModel):
609
764
  split_contents: dict[str, list[str]] = Field(
610
765
  description="The contents of each split in the dataset. The key is the split name, and the value is a list of task run IDs.",
611
766
  )
767
+ filter: DatasetFilterType | None = Field(
768
+ default=None,
769
+ description="The filter used to build the dataset.",
770
+ )
612
771
 
613
772
  @model_validator(mode="after")
614
773
  def validate_split_percentages(self) -> "DatasetSplit":
@@ -623,12 +782,13 @@ class DatasetSplit(KilnParentedModel):
623
782
  name: str,
624
783
  task: "Task",
625
784
  splits: list[DatasetSplitDefinition],
626
- filter: DatasetFilter = AllDatasetFilter,
785
+ filter_type: DatasetFilterType = DatasetFilterType.ALL,
627
786
  description: str | None = None,
628
787
  ):
629
788
  """
630
789
  Build a dataset split from a task.
631
790
  """
791
+ filter = dataset_filters[filter_type]
632
792
  split_contents = cls.build_split_contents(task, splits, filter)
633
793
  return cls(
634
794
  parent=task,
@@ -636,6 +796,7 @@ class DatasetSplit(KilnParentedModel):
636
796
  description=description,
637
797
  splits=splits,
638
798
  split_contents=split_contents,
799
+ filter=filter_type,
639
800
  )
640
801
 
641
802
  @classmethod
@@ -686,7 +847,7 @@ class DatasetSplit(KilnParentedModel):
686
847
  if parent is None:
687
848
  raise ValueError("DatasetSplit has no parent task")
688
849
 
689
- runs = parent.runs()
850
+ runs = parent.runs(readonly=True)
690
851
  all_ids = set(run.id for run in runs)
691
852
  all_ids_in_splits = set()
692
853
  for ids in self.split_contents.values():
@@ -695,6 +856,22 @@ class DatasetSplit(KilnParentedModel):
695
856
  return len(missing)
696
857
 
697
858
 
859
+ class Prompt(KilnParentedModel):
860
+ """
861
+ A prompt for a task.
862
+ """
863
+
864
+ name: str = NAME_FIELD
865
+ prompt: str = Field(
866
+ description="The prompt for the task.",
867
+ min_length=1,
868
+ )
869
+ chain_of_thought_instructions: str | None = Field(
870
+ default=None,
871
+ description="Instructions for the model 'thinking' about the requirement prior to answering. Used for chain of thought style prompting. COT will not be used unless this is provided.",
872
+ )
873
+
874
+
698
875
  class TaskRequirement(BaseModel):
699
876
  """
700
877
  Defines a specific requirement that should be met by task outputs.
@@ -732,6 +909,7 @@ class Task(
732
909
  "runs": TaskRun,
733
910
  "dataset_splits": DatasetSplit,
734
911
  "finetunes": Finetune,
912
+ "prompts": Prompt,
735
913
  },
736
914
  ):
737
915
  """
@@ -768,15 +946,18 @@ class Task(
768
946
  return None
769
947
  return schema_from_json_str(self.input_json_schema)
770
948
 
771
- # Needed for typechecking. TODO P2: fix this in KilnParentModel
772
- def runs(self) -> list[TaskRun]:
773
- return super().runs() # type: ignore
949
+ # These wrappers help for typechecking. TODO P2: fix this in KilnParentModel
950
+ def runs(self, readonly: bool = False) -> list[TaskRun]:
951
+ return super().runs(readonly=readonly) # type: ignore
952
+
953
+ def dataset_splits(self, readonly: bool = False) -> list[DatasetSplit]:
954
+ return super().dataset_splits(readonly=readonly) # type: ignore
774
955
 
775
- def dataset_splits(self) -> list[DatasetSplit]:
776
- return super().dataset_splits() # type: ignore
956
+ def finetunes(self, readonly: bool = False) -> list[Finetune]:
957
+ return super().finetunes(readonly=readonly) # type: ignore
777
958
 
778
- def finetunes(self) -> list[Finetune]:
779
- return super().finetunes() # type: ignore
959
+ def prompts(self, readonly: bool = False) -> list[Prompt]:
960
+ return super().prompts(readonly=readonly) # type: ignore
780
961
 
781
962
 
782
963
  class Project(KilnParentModel, parent_of={"tasks": Task}):
@@ -120,11 +120,12 @@ class KilnBaseModel(BaseModel):
120
120
  return cls.load_from_file(path)
121
121
 
122
122
  @classmethod
123
- def load_from_file(cls: Type[T], path: Path | str) -> T:
123
+ def load_from_file(cls: Type[T], path: Path | str, readonly: bool = False) -> T:
124
124
  """Load a model instance from a specific file path.
125
125
 
126
126
  Args:
127
127
  path (Path): Path to the model file
128
+ readonly (bool): If True, the model will be returned in readonly mode (cached instance, not a copy, not safe to mutate)
128
129
 
129
130
  Returns:
130
131
  T: Instance of the model
@@ -135,10 +136,10 @@ class KilnBaseModel(BaseModel):
135
136
  """
136
137
  if isinstance(path, str):
137
138
  path = Path(path)
138
- cached_model = ModelCache.shared().get_model(path, cls)
139
+ cached_model = ModelCache.shared().get_model(path, cls, readonly=readonly)
139
140
  if cached_model is not None:
140
141
  return cached_model
141
- with open(path, "r") as file:
142
+ with open(path, "r", encoding="utf-8") as file:
142
143
  # modified time of file for cache invalidation. Taken from the file descriptor so it's atomic with the read.
143
144
  mtime_ns = os.fstat(file.fileno()).st_mtime_ns
144
145
  file_data = file.read()
@@ -168,13 +169,20 @@ class KilnBaseModel(BaseModel):
168
169
  # Two ways to indicate it's loaded from file:
169
170
  # 1) info.context.get("loading_from_file") -> During actual loading, before we can set _loaded_from_file
170
171
  # 2) self._loaded_from_file -> After loading, set by the loader
172
+ if self.loading_from_file(info):
173
+ return True
174
+ return self._loaded_from_file
175
+
176
+ # indicates the model is currently being loaded from file (not mutating it after)
177
+ def loading_from_file(self, info: ValidationInfo | None = None) -> bool:
178
+ # info.context.get("loading_from_file") -> During actual loading, before we can set _loaded_from_file
171
179
  if (
172
180
  info is not None
173
181
  and info.context is not None
174
182
  and info.context.get("loading_from_file", False)
175
183
  ):
176
184
  return True
177
- return self._loaded_from_file
185
+ return False
178
186
 
179
187
  def save_to_file(self) -> None:
180
188
  """Save the model instance to a file.
@@ -190,7 +198,7 @@ class KilnBaseModel(BaseModel):
190
198
  )
191
199
  path.parent.mkdir(parents=True, exist_ok=True)
192
200
  json_data = self.model_dump_json(indent=2, exclude={"path"})
193
- with open(path, "w") as file:
201
+ with open(path, "w", encoding="utf-8") as file:
194
202
  file.write(json_data)
195
203
  # save the path so even if something like name changes, the file doesn't move
196
204
  self.path = path
@@ -342,16 +350,28 @@ class KilnParentedModel(KilnBaseModel, metaclass=ABCMeta):
342
350
  return []
343
351
 
344
352
  # Collect all /relationship/{id}/{base_filename.kiln} files in the relationship folder
345
- for child_file in relationship_folder.glob(f"**/{cls.base_filename()}"):
346
- yield child_file
353
+ # manual code instead of glob for performance (5x speedup over glob)
354
+
355
+ base_filename = cls.base_filename()
356
+ # Iterate through immediate subdirectories using scandir for better performance
357
+ # Benchmark: scandir is 10x faster than glob, so worth the extra code
358
+ with os.scandir(relationship_folder) as entries:
359
+ for entry in entries:
360
+ if not entry.is_dir():
361
+ continue
362
+
363
+ child_file = Path(entry.path) / base_filename
364
+ if child_file.is_file():
365
+ yield child_file
347
366
 
348
367
  @classmethod
349
368
  def all_children_of_parent_path(
350
- cls: Type[PT], parent_path: Path | None
369
+ cls: Type[PT], parent_path: Path | None, readonly: bool = False
351
370
  ) -> list[PT]:
352
371
  children = []
353
372
  for child_path in cls.iterate_children_paths_of_parent_path(parent_path):
354
- children.append(cls.load_from_file(child_path))
373
+ item = cls.load_from_file(child_path, readonly=readonly)
374
+ children.append(item)
355
375
  return children
356
376
 
357
377
  @classmethod
@@ -394,8 +414,8 @@ class KilnParentModel(KilnBaseModel, metaclass=ABCMeta):
394
414
  def _create_child_method(
395
415
  cls, relationship_name: str, child_class: Type[KilnParentedModel]
396
416
  ):
397
- def child_method(self) -> list[child_class]:
398
- return child_class.all_children_of_parent_path(self.path)
417
+ def child_method(self, readonly: bool = False) -> list[child_class]:
418
+ return child_class.all_children_of_parent_path(self.path, readonly=readonly)
399
419
 
400
420
  child_method.__name__ = relationship_name
401
421
  child_method.__annotations__ = {"return": List[child_class]}
@@ -42,9 +42,14 @@ def validate_schema(instance: Dict, schema_str: str) -> None:
42
42
  jsonschema.exceptions.ValidationError: If validation fails
43
43
  ValueError: If the schema is invalid
44
44
  """
45
- schema = schema_from_json_str(schema_str)
46
- v = jsonschema.Draft202012Validator(schema)
47
- return v.validate(instance)
45
+ try:
46
+ schema = schema_from_json_str(schema_str)
47
+ v = jsonschema.Draft202012Validator(schema)
48
+ v.validate(instance)
49
+ except jsonschema.exceptions.ValidationError as e:
50
+ raise ValueError(
51
+ f"This task requires a specific output schema. While the model produced JSON, that JSON didn't meet the schema. Search 'Troubleshooting Structured Data Issues' in our docs for more information. The error from the schema check was: {e.message}"
52
+ ) from e
48
53
 
49
54
 
50
55
  def schema_from_json_str(v: str) -> Dict:
@@ -62,12 +62,17 @@ class ModelCache:
62
62
  raise ValueError(f"Model at {path} is not of type {model_type.__name__}")
63
63
  return model
64
64
 
65
- def get_model(self, path: Path, model_type: Type[T]) -> Optional[T]:
66
- # We return a copy so in-memory edits don't impact the cache until they are saved
65
+ def get_model(
66
+ self, path: Path, model_type: Type[T], readonly: bool = False
67
+ ) -> Optional[T]:
68
+ # We return a copy by default, so in-memory edits don't impact the cache until they are saved
67
69
  # Benchmark shows about 2x slower, but much more foolproof
68
70
  model = self._get_model(path, model_type)
69
71
  if model:
70
- return model.model_copy(deep=True)
72
+ if readonly:
73
+ return model
74
+ else:
75
+ return model.model_copy(deep=True)
71
76
  return None
72
77
 
73
78
  def get_model_id(self, path: Path, model_type: Type[T]) -> Optional[str]:
@@ -6,6 +6,9 @@ from unittest.mock import MagicMock, patch
6
6
 
7
7
  import pytest
8
8
 
9
+ from kiln_ai.adapters.model_adapters.base_adapter import AdapterInfo, BaseAdapter
10
+ from kiln_ai.adapters.run_output import RunOutput
11
+ from kiln_ai.datamodel import Task, TaskRun
9
12
  from kiln_ai.datamodel.basemodel import (
10
13
  KilnBaseModel,
11
14
  KilnParentedModel,
@@ -356,7 +359,9 @@ def test_load_from_file_with_cache(test_base_file, tmp_model_cache):
356
359
  model = KilnBaseModel.load_from_file(test_base_file)
357
360
 
358
361
  # Check that the cache was checked and set
359
- tmp_model_cache.get_model.assert_called_once_with(test_base_file, KilnBaseModel)
362
+ tmp_model_cache.get_model.assert_called_once_with(
363
+ test_base_file, KilnBaseModel, readonly=False
364
+ )
360
365
  tmp_model_cache.set_model.assert_called_once()
361
366
 
362
367
  # Ensure the model is correctly loaded
@@ -407,7 +412,9 @@ def test_load_from_file_with_cached_model(test_base_file, tmp_model_cache):
407
412
  model = KilnBaseModel.load_from_file(test_base_file)
408
413
 
409
414
  # Check that the cache was checked and the cached model was returned
410
- tmp_model_cache.get_model.assert_called_once_with(test_base_file, KilnBaseModel)
415
+ tmp_model_cache.get_model.assert_called_once_with(
416
+ test_base_file, KilnBaseModel, readonly=False
417
+ )
411
418
  assert model is cached_model
412
419
 
413
420
  # Assert that open was not called (we used the cached model, not file)
@@ -469,3 +476,75 @@ def test_from_id_and_parent_path_without_parent():
469
476
  # Test with None parent_path
470
477
  not_found = DefaultParentedModel.from_id_and_parent_path("any-id", None)
471
478
  assert not_found is None
479
+
480
+
481
+ class MockAdapter(BaseAdapter):
482
+ """Implementation of BaseAdapter for testing"""
483
+
484
+ async def _run(self, input):
485
+ return RunOutput(output="test output", intermediate_outputs=None)
486
+
487
+ def adapter_info(self) -> AdapterInfo:
488
+ return AdapterInfo(
489
+ adapter_name="test",
490
+ model_name=self.model_name,
491
+ model_provider=self.model_provider_name,
492
+ prompt_builder_name="test",
493
+ )
494
+
495
+
496
+ @pytest.fixture
497
+ def base_task():
498
+ return Task(name="test_task", instruction="test_instruction")
499
+
500
+
501
+ @pytest.fixture
502
+ def adapter(base_task):
503
+ return MockAdapter(
504
+ kiln_task=base_task,
505
+ model_name="test_model",
506
+ model_provider_name="test_provider",
507
+ )
508
+
509
+
510
+ async def test_invoke_parsing_flow(adapter):
511
+ # Mock dependencies
512
+ mock_provider = MagicMock()
513
+ mock_provider.parser = "test_parser"
514
+
515
+ mock_parser = MagicMock()
516
+ mock_parser.parse_output.return_value = RunOutput(
517
+ output="parsed test output", intermediate_outputs={"key": "value"}
518
+ )
519
+
520
+ mock_parser_class = MagicMock(return_value=mock_parser)
521
+
522
+ with (
523
+ patch.object(adapter, "model_provider", return_value=mock_provider),
524
+ patch(
525
+ "kiln_ai.adapters.model_adapters.base_adapter.model_parser_from_id",
526
+ return_value=mock_parser_class,
527
+ ),
528
+ patch("kiln_ai.adapters.model_adapters.base_adapter.Config") as mock_config,
529
+ ):
530
+ # Disable autosaving for this test
531
+ mock_config.shared.return_value.autosave_runs = False
532
+ mock_config.shared.return_value.user_id = "test_user_id"
533
+
534
+ # Execute
535
+ result = await adapter.invoke("test input")
536
+
537
+ # Verify parser was created correctly
538
+ mock_parser_class.assert_called_once_with(structured_output=False)
539
+
540
+ # Verify parsing occurred
541
+ mock_parser.parse_output.assert_called_once()
542
+ parsed_args = mock_parser.parse_output.call_args[1]
543
+ assert isinstance(parsed_args["original_output"], RunOutput)
544
+ assert parsed_args["original_output"].output == "test output"
545
+
546
+ # Verify result contains parsed output
547
+ assert isinstance(result, TaskRun)
548
+ assert result.output.output == "parsed test output"
549
+ assert result.intermediate_outputs == {"key": "value"}
550
+ assert result.input == "test input"