kiln-ai 0.6.1__py3-none-any.whl → 0.7.1__py3-none-any.whl

This diff shows the changes between two publicly released versions of this package, as they appear in their respective public registries. It is provided for informational purposes only.

Potentially problematic release: this version of kiln-ai might be problematic.

Files changed (44)
  1. kiln_ai/adapters/__init__.py +2 -0
  2. kiln_ai/adapters/adapter_registry.py +19 -0
  3. kiln_ai/adapters/data_gen/test_data_gen_task.py +29 -21
  4. kiln_ai/adapters/fine_tune/__init__.py +14 -0
  5. kiln_ai/adapters/fine_tune/base_finetune.py +186 -0
  6. kiln_ai/adapters/fine_tune/dataset_formatter.py +187 -0
  7. kiln_ai/adapters/fine_tune/finetune_registry.py +11 -0
  8. kiln_ai/adapters/fine_tune/fireworks_finetune.py +308 -0
  9. kiln_ai/adapters/fine_tune/openai_finetune.py +205 -0
  10. kiln_ai/adapters/fine_tune/test_base_finetune.py +290 -0
  11. kiln_ai/adapters/fine_tune/test_dataset_formatter.py +342 -0
  12. kiln_ai/adapters/fine_tune/test_fireworks_tinetune.py +455 -0
  13. kiln_ai/adapters/fine_tune/test_openai_finetune.py +503 -0
  14. kiln_ai/adapters/langchain_adapters.py +103 -13
  15. kiln_ai/adapters/ml_model_list.py +239 -303
  16. kiln_ai/adapters/ollama_tools.py +115 -0
  17. kiln_ai/adapters/provider_tools.py +308 -0
  18. kiln_ai/adapters/repair/repair_task.py +4 -2
  19. kiln_ai/adapters/repair/test_repair_task.py +6 -11
  20. kiln_ai/adapters/test_langchain_adapter.py +229 -18
  21. kiln_ai/adapters/test_ollama_tools.py +42 -0
  22. kiln_ai/adapters/test_prompt_adaptors.py +7 -5
  23. kiln_ai/adapters/test_provider_tools.py +531 -0
  24. kiln_ai/adapters/test_structured_output.py +22 -43
  25. kiln_ai/datamodel/__init__.py +287 -24
  26. kiln_ai/datamodel/basemodel.py +122 -38
  27. kiln_ai/datamodel/model_cache.py +116 -0
  28. kiln_ai/datamodel/registry.py +31 -0
  29. kiln_ai/datamodel/test_basemodel.py +167 -4
  30. kiln_ai/datamodel/test_dataset_split.py +234 -0
  31. kiln_ai/datamodel/test_example_models.py +12 -0
  32. kiln_ai/datamodel/test_model_cache.py +244 -0
  33. kiln_ai/datamodel/test_models.py +215 -1
  34. kiln_ai/datamodel/test_registry.py +96 -0
  35. kiln_ai/utils/config.py +14 -1
  36. kiln_ai/utils/name_generator.py +125 -0
  37. kiln_ai/utils/test_name_geneator.py +47 -0
  38. kiln_ai-0.7.1.dist-info/METADATA +237 -0
  39. kiln_ai-0.7.1.dist-info/RECORD +58 -0
  40. {kiln_ai-0.6.1.dist-info → kiln_ai-0.7.1.dist-info}/WHEEL +1 -1
  41. kiln_ai/adapters/test_ml_model_list.py +0 -181
  42. kiln_ai-0.6.1.dist-info/METADATA +0 -88
  43. kiln_ai-0.6.1.dist-info/RECORD +0 -37
  44. {kiln_ai-0.6.1.dist-info → kiln_ai-0.7.1.dist-info}/licenses/LICENSE.txt +0 -0
kiln_ai/datamodel/__init__.py

```diff
@@ -1,12 +1,23 @@
+"""
+See our docs for details about our datamodel: https://kiln-ai.github.io/Kiln/kiln_core_docs/kiln_ai.html
+"""
+
 from __future__ import annotations
 
 import json
+import math
+import random
 from enum import Enum, IntEnum
-from typing import TYPE_CHECKING, Dict, List, Type, Union
+from typing import TYPE_CHECKING, Callable, Dict, List, Type, Union
 
 import jsonschema
 import jsonschema.exceptions
-from pydantic import BaseModel, Field, model_validator
+from pydantic import (
+    BaseModel,
+    Field,
+    ValidationInfo,
+    model_validator,
+)
 from typing_extensions import Self
 
 from kiln_ai.datamodel.json_schema import JsonObjectSchema, schema_from_json_str
@@ -14,6 +25,8 @@ from kiln_ai.datamodel.json_schema import JsonObjectSchema, schema_from_json_str
 from .basemodel import (
     ID_FIELD,
     ID_TYPE,
+    NAME_FIELD,
+    SHORT_NAME_FIELD,
     KilnBaseModel,
     KilnParentedModel,
     KilnParentModel,
@@ -39,27 +52,23 @@ __all__ = [
     "TaskOutputRatingType",
     "TaskRequirement",
     "TaskDeterminism",
+    "strict_mode",
+    "set_strict_mode",
 ]
 
 
-# Conventions:
-# 1) Names are filename safe as they may be used as file names. They are informational and not to be used in prompts/training/validation.
-# 2) Descrptions are for Kiln users to describe/understanding the purpose of this object. They must never be used in prompts/training/validation. Use "instruction/requirements" instead.
+# We want to be hard on ourselves for data completeness generated by the Kiln App, but don't want to make it hard for users to use the datamodel/library.
+# Strict mode enables extra validations that we want to enforce in Kiln App (and any other client that wants best practices), but not in the library (unless they opt in)
+_strict_mode: bool = False
 
-# Filename compatible names
-NAME_REGEX = r"^[A-Za-z0-9 _-]+$"
-NAME_FIELD = Field(
-    min_length=1,
-    max_length=120,
-    pattern=NAME_REGEX,
-    description="A name for this entity.",
-)
-SHORT_NAME_FIELD = Field(
-    min_length=1,
-    max_length=32,
-    pattern=NAME_REGEX,
-    description="A name for this entity",
-)
+
+def strict_mode() -> bool:
+    return _strict_mode
+
+
+def set_strict_mode(value: bool) -> None:
+    global _strict_mode
+    _strict_mode = value
 
 
 class Priority(IntEnum):
```
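The name-field constants removed above move to basemodel.py (see that file's hunks below); in their place, __init__.py gains a process-wide strict-mode toggle. A minimal sketch of opting in, assuming the library's lenient default; note that in pydantic v2 the validator's ValueError surfaces as a ValidationError, which subclasses ValueError:

```python
from kiln_ai import datamodel

# Library default is lenient: a TaskOutput with no source validates fine.
datamodel.TaskOutput(output="plain text result")

# Clients that want App-level data completeness opt in once at startup.
datamodel.set_strict_mode(True)

# Newly constructed (not loaded-from-file) outputs must now carry a source.
try:
    datamodel.TaskOutput(output="plain text result")
except ValueError:
    print("rejected: output source is required when strict mode is enabled")
```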
```diff
@@ -137,8 +146,9 @@ class TaskOutput(KilnBaseModel):
     output: str = Field(
         description="The output of the task. JSON formatted for structured output, plaintext for unstructured output."
     )
-    source: DataSource = Field(
-        description="The source of the output: human or synthetic."
+    source: DataSource | None = Field(
+        description="The source of the output: human or synthetic.",
+        default=None,
     )
     rating: TaskOutputRating | None = Field(
         default=None, description="The rating of the output"
@@ -155,6 +165,83 @@
             raise ValueError(f"Output does not match task output schema: {e}")
         return self
 
+    @model_validator(mode="after")
+    def validate_output_source(self, info: ValidationInfo) -> Self:
+        # On strict mode and not loaded from file, we validate output_source is not None.
+        # We want to be able to load any data, even if it's not perfect. But we want to create perfect data when adding new data.
+        if not strict_mode():
+            return self
+        if self.loaded_from_file(info):
+            return self
+        if self.source is None:
+            raise ValueError("Output source is required when strict mode is enabled")
+        return self
+
+
+class FineTuneStatusType(str, Enum):
+    """
+    The status type of a fine-tune (running, completed, failed, etc).
+    """
+
+    unknown = "unknown"  # server error
+    pending = "pending"
+    running = "running"
+    completed = "completed"
+    failed = "failed"
+
+
+class Finetune(KilnParentedModel):
+    name: str = NAME_FIELD
+    description: str | None = Field(
+        default=None,
+        description="A description of the fine-tune for you and your team. Not used in training.",
+    )
+    provider: str = Field(
+        description="The provider to use for the fine-tune (e.g. 'openai')."
+    )
+    base_model_id: str = Field(
+        description="The id of the base model to use for the fine-tune. This string relates to the provider's IDs for their own models, not Kiln IDs."
+    )
+    provider_id: str | None = Field(
+        default=None,
+        description="The ID of the fine-tune job on the provider's side. May not be the same as the fine_tune_model_id.",
+    )
+    fine_tune_model_id: str | None = Field(
+        default=None,
+        description="The ID of the fine-tuned model on the provider's side. May not be the same as the provider_id.",
+    )
+    dataset_split_id: str = Field(
+        description="The ID of the dataset split to use for this fine-tune.",
+    )
+    train_split_name: str = Field(
+        default="train",
+        description="The name of the training split to use for this fine-tune.",
+    )
+    validation_split_name: str | None = Field(
+        default=None,
+        description="The name of the validation split to use for this fine-tune. Optional.",
+    )
+    parameters: dict[str, str | int | float | bool] = Field(
+        default={},
+        description="The parameters to use for this fine-tune. These are provider-specific.",
+    )
+    system_message: str = Field(
+        description="The system message to use for this fine-tune.",
+    )
+    latest_status: FineTuneStatusType = Field(
+        default=FineTuneStatusType.unknown,
+        description="The latest known status of this fine-tune. Not updated in real time.",
+    )
+    properties: Dict[str, str | int | float] = Field(
+        default={},
+        description="Properties of the fine-tune. Different providers may use different properties.",
+    )
+
+    def parent_task(self) -> Task | None:
+        if not isinstance(self.parent, Task):
+            return None
+        return self.parent
+
 
 class DataSourceType(str, Enum):
     """
```
```diff
@@ -277,8 +364,8 @@ class TaskRun(KilnParentedModel):
     input: str = Field(
         description="The inputs to the task. JSON formatted for structured input, plaintext for unstructured input."
     )
-    input_source: DataSource = Field(
-        description="The source of the input: human or synthetic."
+    input_source: DataSource | None = Field(
+        default=None, description="The source of the input: human or synthetic."
     )
 
     output: TaskOutput = Field(description="The output of the task run.")
@@ -343,6 +430,172 @@
         )
         return self
 
+    @model_validator(mode="after")
+    def validate_input_source(self, info: ValidationInfo) -> Self:
+        # On strict mode and not loaded from file, we validate input_source is not None.
+        # We want to be able to load any data, even if it's not perfect. But we want to create perfect data when adding new data.
+        if not strict_mode():
+            return self
+        if self.loaded_from_file(info):
+            return self
+        if self.input_source is None:
+            raise ValueError("input_source is required when strict mode is enabled")
+        return self
+
+
+# Define the type alias for clarity
+DatasetFilter = Callable[[TaskRun], bool]
+
+
+def AllDatasetFilter(_: TaskRun) -> bool:
+    return True
+
+
+def HighRatingDatasetFilter(task_run: TaskRun) -> bool:
+    if task_run.output is None or task_run.output.rating is None:
+        return False
+    return task_run.output.rating.is_high_quality()
+
+
+class DatasetSplitDefinition(BaseModel):
+    """
+    A definition of a split in a dataset.
+
+    Example: name="train", description="The training set", percentage=0.8 (80% of the dataset)
+    """
+
+    name: str = NAME_FIELD
+    description: str | None = Field(
+        default=None,
+        description="A description of the dataset for you and your team. Not used in training.",
+    )
+    percentage: float = Field(
+        ge=0.0,
+        le=1.0,
+        description="The percentage of the dataset that this split represents (between 0 and 1).",
+    )
+
+
+AllSplitDefinition: list[DatasetSplitDefinition] = [
+    DatasetSplitDefinition(name="all", percentage=1.0)
+]
+Train80Test20SplitDefinition: list[DatasetSplitDefinition] = [
+    DatasetSplitDefinition(name="train", percentage=0.8),
+    DatasetSplitDefinition(name="test", percentage=0.2),
+]
+Train60Test20Val20SplitDefinition: list[DatasetSplitDefinition] = [
+    DatasetSplitDefinition(name="train", percentage=0.6),
+    DatasetSplitDefinition(name="test", percentage=0.2),
+    DatasetSplitDefinition(name="val", percentage=0.2),
+]
+
+
+class DatasetSplit(KilnParentedModel):
+    """
+    A collection of task runs, with optional splits (train, test, validation).
+
+    Used to freeze a dataset into train/test/validation splits for repeatable fine-tuning or other tasks.
+
+    Maintains a list of IDs for each split, to avoid data duplication.
+    """
+
+    name: str = NAME_FIELD
+    description: str | None = Field(
+        default=None,
+        description="A description of the dataset for you and your team. Not used in training.",
+    )
+    splits: list[DatasetSplitDefinition] = Field(
+        default_factory=list,
+        description="The splits in the dataset.",
+    )
+    split_contents: dict[str, list[str]] = Field(
+        description="The contents of each split in the dataset. The key is the split name, and the value is a list of task run IDs.",
+    )
+
+    @model_validator(mode="after")
+    def validate_split_percentages(self) -> "DatasetSplit":
+        total = sum(split.percentage for split in self.splits)
+        if not math.isclose(total, 1.0, rel_tol=1e-9):
+            raise ValueError(f"The sum of split percentages must be 1.0 (got {total})")
+        return self
+
+    @classmethod
+    def from_task(
+        cls,
+        name: str,
+        task: "Task",
+        splits: list[DatasetSplitDefinition],
+        filter: DatasetFilter = AllDatasetFilter,
+        description: str | None = None,
+    ):
+        """
+        Build a dataset split from a task.
+        """
+        split_contents = cls.build_split_contents(task, splits, filter)
+        return cls(
+            parent=task,
+            name=name,
+            description=description,
+            splits=splits,
+            split_contents=split_contents,
+        )
+
+    @classmethod
+    def build_split_contents(
+        cls,
+        task: "Task",
+        splits: list[DatasetSplitDefinition],
+        filter: DatasetFilter,
+    ) -> dict[str, list[str]]:
+        valid_ids = []
+        for task_run in task.runs():
+            if filter(task_run):
+                valid_ids.append(task_run.id)
+
+        # Shuffle and split by split percentage
+        random.shuffle(valid_ids)
+        split_contents = {}
+        start_idx = 0
+        remaining_items = len(valid_ids)
+
+        # Handle all splits except the last one
+        for split in splits[:-1]:
+            split_size = round(len(valid_ids) * split.percentage)
+            split_contents[split.name] = valid_ids[start_idx : start_idx + split_size]
+            start_idx += split_size
+            remaining_items -= split_size
+
+        # Last split gets all remaining items (for rounding)
+        if splits:
+            split_contents[splits[-1].name] = valid_ids[start_idx:]
+
+        return split_contents
+
+    def parent_task(self) -> "Task | None":
+        # inline import to avoid circular import
+        from kiln_ai.datamodel import Task
+
+        if not isinstance(self.parent, Task):
+            return None
+        return self.parent
+
+    def missing_count(self) -> int:
+        """
+        Returns:
+            int: the number of task runs that have an ID persisted in this dataset split, but no longer exist in the dataset
+        """
+        parent = self.parent_task()
+        if parent is None:
+            raise ValueError("DatasetSplit has no parent task")
+
+        runs = parent.runs()
+        all_ids = set(run.id for run in runs)
+        all_ids_in_splits = set()
+        for ids in self.split_contents.values():
+            all_ids_in_splits.update(ids)
+        missing = all_ids_in_splits - all_ids
+        return len(missing)
+
 
 class TaskRequirement(BaseModel):
     """
```
```diff
@@ -376,7 +629,11 @@ class TaskDeterminism(str, Enum):
 class Task(
     KilnParentedModel,
     KilnParentModel,
-    parent_of={"runs": TaskRun},
+    parent_of={
+        "runs": TaskRun,
+        "dataset_splits": DatasetSplit,
+        "finetunes": Finetune,
+    },
 ):
     """
     Represents a specific task to be performed, with associated requirements and validation rules.
@@ -416,6 +673,12 @@
     def runs(self) -> list[TaskRun]:
         return super().runs()  # type: ignore
 
+    def dataset_splits(self) -> list[DatasetSplit]:
+        return super().dataset_splits()  # type: ignore
+
+    def finetunes(self) -> list[Finetune]:
+        return super().finetunes()  # type: ignore
+
 
 class Project(KilnParentModel, parent_of={"tasks": Task}):
     """
```
kiln_ai/datamodel/basemodel.py

```diff
@@ -1,4 +1,6 @@
 import json
+import os
+import re
 import shutil
 import uuid
 from abc import ABCMeta
@@ -6,7 +8,6 @@ from builtins import classmethod
 from datetime import datetime
 from pathlib import Path
 from typing import (
-    TYPE_CHECKING,
     Any,
     Dict,
     List,
@@ -20,12 +21,14 @@ from pydantic import (
     ConfigDict,
     Field,
     ValidationError,
+    ValidationInfo,
     computed_field,
     model_validator,
 )
 from pydantic_core import ErrorDetails
 from typing_extensions import Self
 
+from kiln_ai.datamodel.model_cache import ModelCache
 from kiln_ai.utils.config import Config
 from kiln_ai.utils.formatting import snake_case
 
@@ -39,6 +42,35 @@ T = TypeVar("T", bound="KilnBaseModel")
 PT = TypeVar("PT", bound="KilnParentedModel")
 
 
+# Naming conventions:
+# 1) Names are filename safe as they may be used as file names. They are informational and not to be used in prompts/training/validation.
+# 2) Descrptions are for Kiln users to describe/understanding the purpose of this object. They must never be used in prompts/training/validation. Use "instruction/requirements" instead.
+
+# Filename compatible names
+NAME_REGEX = r"^[A-Za-z0-9 _-]+$"
+NAME_FIELD = Field(
+    min_length=1,
+    max_length=120,
+    pattern=NAME_REGEX,
+    description="A name for this entity.",
+)
+SHORT_NAME_FIELD = Field(
+    min_length=1,
+    max_length=32,
+    pattern=NAME_REGEX,
+    description="A name for this entity",
+)
+
+
+def string_to_valid_name(name: str) -> str:
+    # Replace any character not allowed by NAME_REGEX with an underscore
+    valid_name = re.sub(r"[^A-Za-z0-9 _-]", "_", name)
+    # Replace consecutive underscores with a single underscore
+    valid_name = re.sub(r"_+", "_", valid_name)
+    # Remove leading and trailing underscores or whitespace
+    return valid_name.strip("_").strip()
+
+
 class KilnBaseModel(BaseModel):
     """Base model for all Kiln data models with common functionality for persistence and versioning.
 
```
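The name constants now live in basemodel.py, next to a new helper that coerces arbitrary strings into NAME_REGEX-safe names. Expected behavior, traced from the regex and strip logic above:

```python
from kiln_ai.datamodel.basemodel import string_to_valid_name

# Disallowed characters become underscores, runs of underscores collapse,
# and leading/trailing underscores and whitespace are stripped.
print(string_to_valid_name("My Task: v2 (draft)"))  # -> "My Task_ v2 _draft"
print(string_to_valid_name("__already_safe__"))     # -> "already_safe"
```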
```diff
@@ -58,6 +90,8 @@ class KilnBaseModel(BaseModel):
     created_at: datetime = Field(default_factory=datetime.now)
     created_by: str = Field(default_factory=lambda: Config.shared().user_id)
 
+    _loaded_from_file: bool = False
+
     @computed_field()
     def model_type(self) -> str:
         return self.type_name()
@@ -86,7 +120,7 @@ class KilnBaseModel(BaseModel):
         return cls.load_from_file(path)
 
     @classmethod
-    def load_from_file(cls: Type[T], path: Path) -> T:
+    def load_from_file(cls: Type[T], path: Path | str) -> T:
         """Load a model instance from a specific file path.
 
         Args:
@@ -97,15 +131,28 @@
 
         Raises:
             ValueError: If the loaded model is not of the expected type or version
+            FileNotFoundError: If the file does not exist
         """
+        if isinstance(path, str):
+            path = Path(path)
+        cached_model = ModelCache.shared().get_model(path, cls)
+        if cached_model is not None:
+            return cached_model
         with open(path, "r") as file:
+            # modified time of file for cache invalidation. From file descriptor so it's atomic w read.
+            mtime_ns = os.fstat(file.fileno()).st_mtime_ns
             file_data = file.read()
             # TODO P2 perf: parsing the JSON twice here.
             # Once for model_type, once for model. Can't call model_validate with parsed json because enum types break; they get strings instead of enums.
             parsed_json = json.loads(file_data)
-            m = cls.model_validate_json(file_data, strict=True)
+            m = cls.model_validate_json(
+                file_data,
+                strict=True,
+                context={"loading_from_file": True},
+            )
             if not isinstance(m, cls):
                 raise ValueError(f"Loaded model is not of type {cls.__name__}")
+            m._loaded_from_file = True
             file_data = None
         m.path = path
         if m.v > m.max_schema_version():
```
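load_from_file is now cache-aware: it accepts plain str paths, consults the shared ModelCache before parsing, and records the file's mtime (read via os.fstat on the open descriptor, so it is atomic with the read) for invalidation. A sketch of the intended effect, with a hypothetical path; ModelCache itself lives in the new model_cache.py, which is not shown in this diff:

```python
from kiln_ai.datamodel import Task

path = "/projects/demo/tasks/abc/task.kiln"  # hypothetical; str paths now accepted

task = Task.load_from_file(path)        # parses JSON, validates, populates the cache
task_again = Task.load_from_file(path)  # served from ModelCache while the mtime is unchanged
```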
```diff
@@ -120,8 +167,21 @@
                 f"Class: {m.__class__.__name__}, id: {getattr(m, 'id', None)}, path: {path}, "
                 f"version: {m.v}, max version: {m.max_schema_version()}"
             )
+        ModelCache.shared().set_model(path, m, mtime_ns)
         return m
 
+    def loaded_from_file(self, info: ValidationInfo | None = None) -> bool:
+        # Two methods of indicated it's loaded from file:
+        # 1) info.context.get("loading_from_file") -> During actual loading, before we can set _loaded_from_file
+        # 2) self._loaded_from_file -> After loading, set by the loader
+        if (
+            info is not None
+            and info.context is not None
+            and info.context.get("loading_from_file", False)
+        ):
+            return True
+        return self._loaded_from_file
+
     def save_to_file(self) -> None:
         """Save the model instance to a file.
 
@@ -140,6 +200,9 @@
             file.write(json_data)
         # save the path so even if something like name changes, the file doesn't move
         self.path = path
+        # We could save, but invalidating will trigger load on next use.
+        # This ensures everything in cache is loaded from disk, and the cache perfectly reflects what's on disk
+        ModelCache.shared().invalidate(path)
 
     def delete(self) -> None:
         if self.path is None:
@@ -148,6 +211,7 @@
         if dir_path is None:
             raise ValueError("Cannot delete model because path is not set")
         shutil.rmtree(dir_path)
+        ModelCache.shared().invalidate(self.path)
         self.path = None
 
     def build_path(self) -> Path | None:
@@ -167,51 +231,44 @@ class KilnParentedModel(KilnBaseModel, metaclass=ABCMeta):
     including parent reference handling and file system organization.
 
     Attributes:
-        _parent (KilnBaseModel): Reference to the parent model instance
+        parent (KilnBaseModel): Reference to the parent model instance. Not persisted, just in memory.
     """
 
-    _parent: KilnBaseModel | None = None
+    # Parent is an in memory only reference to parent. If it's set we use that. If not we'll try to load it from disk based on the path.
+    # We don't persist the parent reference to disk. See the accessors below for how we make it a clean api (parent accessor will lazy load from disk)
+    parent: Optional[KilnBaseModel] = Field(default=None, exclude=True)
 
-    # workaround to tell typechecker that we support the parent property, even though it's not a stock property
-    if TYPE_CHECKING:
-        parent: KilnBaseModel  # type: ignore
+    def __getattribute__(self, name: str) -> Any:
+        if name == "parent":
+            return self.load_parent()
+        return super().__getattribute__(name)
 
-    def __init__(self, **data):
-        super().__init__(**data)
-        if "parent" in data:
-            self.parent = data["parent"]
+    def cached_parent(self) -> Optional[KilnBaseModel]:
+        return object.__getattribute__(self, "parent")
 
-    @property
-    def parent(self) -> Optional[KilnBaseModel]:
+    def load_parent(self) -> Optional[KilnBaseModel]:
         """Get the parent model instance, loading it from disk if necessary.
 
         Returns:
             Optional[KilnBaseModel]: The parent model instance or None if not set
         """
-        if self._parent is not None:
-            return self._parent
+        cached_parent = self.cached_parent()
+        if cached_parent is not None:
+            return cached_parent
+
         # lazy load parent from path
         if self.path is None:
            return None
-        # TODO: this only works with base_filename. If we every support custom names, we need to change this.
+        # Note: this only works with base_filename. If we every support custom names, we need to change this.
         parent_path = (
             self.path.parent.parent.parent
             / self.__class__.parent_type().base_filename()
         )
         if parent_path is None:
             return None
-        self._parent = self.__class__.parent_type().load_from_file(parent_path)
-        return self._parent
-
-    @parent.setter
-    def parent(self, value: Optional[KilnBaseModel]):
-        if value is not None:
-            expected_parent_type = self.__class__.parent_type()
-            if not isinstance(value, expected_parent_type):
-                raise ValueError(
-                    f"Parent must be of type {expected_parent_type}, but was {type(value)}"
-                )
-        self._parent = value
+        loaded_parent = self.__class__.parent_type().load_from_file(parent_path)
+        self.parent = loaded_parent
+        return loaded_parent
 
     # Dynamically implemented by KilnParentModel method injection
     @classmethod
```
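The private _parent attribute and its property/setter pair are replaced by a real pydantic field (excluded from serialization) plus a __getattribute__ hook, preserving the lazy load from disk. A sketch of the resulting semantics, with a hypothetical run file:

```python
from kiln_ai.datamodel import TaskRun

run = TaskRun.load_from_file("/projects/demo/tasks/abc/runs/xyz/task_run.kiln")  # hypothetical path

assert run.cached_parent() is None    # no parent reference in memory yet
parent = run.parent                   # __getattribute__ routes to load_parent(), reading the parent file
assert run.cached_parent() is parent  # cached for subsequent accesses
```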
```diff
@@ -225,11 +282,12 @@
 
     @model_validator(mode="after")
     def check_parent_type(self) -> Self:
-        if self._parent is not None:
+        cached_parent = self.cached_parent()
+        if cached_parent is not None:
             expected_parent_type = self.__class__.parent_type()
-            if not isinstance(self._parent, expected_parent_type):
+            if not isinstance(cached_parent, expected_parent_type):
                 raise ValueError(
-                    f"Parent must be of type {expected_parent_type}, but was {type(self._parent)}"
+                    f"Parent must be of type {expected_parent_type}, but was {type(cached_parent)}"
                 )
         return self
 
@@ -268,9 +326,7 @@
         )
 
     @classmethod
-    def all_children_of_parent_path(
-        cls: Type[PT], parent_path: Path | None
-    ) -> list[PT]:
+    def iterate_children_paths_of_parent_path(cls: Type[PT], parent_path: Path | None):
         if parent_path is None:
             # children are disk based. If not saved, they don't exist
             return []
@@ -292,13 +348,41 @@
             return []
 
         # Collect all /relationship/{id}/{base_filename.kiln} files in the relationship folder
-        children = []
         for child_file in relationship_folder.glob(f"**/{cls.base_filename()}"):
-            child = cls.load_from_file(child_file)
-            children.append(child)
+            yield child_file
 
+    @classmethod
+    def all_children_of_parent_path(
+        cls: Type[PT], parent_path: Path | None
+    ) -> list[PT]:
+        children = []
+        for child_path in cls.iterate_children_paths_of_parent_path(parent_path):
+            children.append(cls.load_from_file(child_path))
         return children
 
+    @classmethod
+    def from_id_and_parent_path(
+        cls: Type[PT], id: str, parent_path: Path | None
+    ) -> PT | None:
+        """
+        Fast search by ID using the cache. Avoids the model_copy overhead on all but the exact match.
+
+        Uses cache so still slow on first load.
+        """
+        if parent_path is None:
+            return None
+
+        # Note: we're using the in-file ID. We could make this faster using the path-ID if this becomes perf bottleneck, but it's better to have 1 source of truth.
+        for child_path in cls.iterate_children_paths_of_parent_path(parent_path):
+            child_id = ModelCache.shared().get_model_id(child_path, cls)
+            if child_id == id:
+                return cls.load_from_file(child_path)
+            if child_id is None:
+                child = cls.load_from_file(child_path)
+                if child.id == id:
+                    return child
+        return None
+
 
 # Parent create methods for all child relationships
 # You must pass in parent_of in the subclass definition, defining the child relationships
```
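Child enumeration is now a generator of paths, and from_id_and_parent_path answers ID lookups from ModelCache's stored IDs, so only cache misses and the matching child pay the full load-and-validate cost. A usage sketch, assuming the task file and run ID exist:

```python
from kiln_ai.datamodel import Task, TaskRun

task = Task.load_from_file("/projects/demo/tasks/abc/task.kiln")  # hypothetical path

run = TaskRun.from_id_and_parent_path("123456789012", task.path)
if run is None:
    print("no run with that ID under this task")
```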